From 70a6a7e0e163771e87dcef00e57c4ef8eb16ec97 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Fri, 10 May 2019 15:23:25 +0100 Subject: [PATCH 01/60] Implement schema for run query filters --- faculty/clients/experiment.py | 122 ++++++++++++++++++++++++++++++- tests/clients/test_experiment.py | 42 +++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1c76b2f1..b98d59e0 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -14,8 +14,9 @@ from collections import namedtuple from enum import Enum +import uuid -from marshmallow import fields, post_load +from marshmallow import fields, post_load, utils as marshmallow_utils from marshmallow_enum import EnumField from faculty.clients.base import BaseClient, BaseSchema, Conflict @@ -43,6 +44,12 @@ def __init__(self, message, experiment_id): self.experiment_id = experiment_id +class RunQueryFilterValidation(Exception): + def __init__(self, message, value): + super(RunQueryFilterValidation, self).__init__(message) + self.value = value + + class ExperimentRunStatus(Enum): RUNNING = "running" FINISHED = "finished" @@ -99,6 +106,119 @@ class ExperimentRunStatus(Enum): "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] ) +SingleFilter = namedtuple("SingleFilter", ["by", "key", "operator", "value"]) + +CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) + + +class SingleFilterOperator(Enum): + DEFINED = "defined" + EQUAL_TO = "eq" + NOT_EQUAL_TO = "ne" + LESS_THAN = "lt" + LESS_THAN_OR_EQUAL_TO = "le" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL_TO = "ge" + + +class SingleFilterBy(Enum): + PROJECT_ID = "projectId" + EXPERIMENT_ID = "experimentId" + RUN_ID = "runId" + DELETED_AT = "deletedAt" + TAG = "tag" + PARAM = "param" + METRIC = "metric" + + +class CompoundFilterOperator(Enum): + AND = "and" + OR = "or" + + +class SingleFilterValueField(fields.Field): + """ + 
Field that serialises/deserialises a run filter. + """ + + def _is_valid_uuid(self, value, obj): + return ( + isinstance(value, uuid.UUID) + and ( + obj.by == SingleFilterBy.PROJECT_ID + or obj.by == SingleFilterBy.RUN_ID + ) + ) + + def _is_valid_experiment_id(self, value, obj): + return ( + isinstance(value, int) + and obj.by == SingleFilterBy.EXPERIMENT_ID + ) + + def _is_directly_stringifiable(self, value, obj): + return ( + self._is_valid_uuid(value, obj) + or self._is_valid_experiment_id(value, obj) + or obj.by == SingleFilterBy.TAG + or obj.by == SingleFilterBy.PARAM + or obj.by == SingleFilterBy.METRIC + ) + + def _deserialize(self, value, attr, obj, **kwargs): + pass + + def _serialize(self, value, attr, obj, **kwargs): + if self._is_directly_stringifiable(value, obj): + return str(value) + elif obj.by == SingleFilterBy.DELETED_AT: + return marshmallow_utils.from_iso_datetime(str(value)) + else: + raise RunQueryFilterValidation( + "Validation error serialising run query filter", + value + ) + + +class FilterField(fields.Field): + """ + Field that serialises/deserialises a run filter. 
+ """ + + def _deserialize(self, value, attr, obj, **kwargs): + if value is None: + return None + elif isinstance(value, SingleFilter): + return SingleFilterSchema().load(value) + else: + return CompoundFilterSchema().load(value) + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + if isinstance(value, SingleFilter): + return SingleFilterSchema().dump(value) + else: + return CompoundFilterSchema().dump(value) + + +class SingleFilterSchema(BaseSchema): + by = EnumField(SingleFilterBy, by_value=True, required=True) + operator = EnumField(SingleFilterOperator, by_value=True, required=True) + value = SingleFilterValueField(required=True) + + +class CompoundFilterSchema(BaseSchema): + operator = EnumField( + CompoundFilterOperator, by_value=True, required=True) + conditions = fields.List(FilterField()) + + +class QueryRunsSchema(BaseSchema): + filter = FilterField(required=True) + sort = fields.String() + page = fields.String() + class MetricSchema(BaseSchema): key = fields.String(required=True) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 7b5455f0..d2aad3d7 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -23,6 +23,8 @@ from faculty.clients.base import Conflict from faculty.clients.experiment import ( CreateRunSchema, + CompoundFilter, + CompoundFilterOperator, DeleteExperimentRunsResponse, DeleteExperimentRunsResponseSchema, Experiment, @@ -47,8 +49,12 @@ Param, ParamConflict, ParamSchema, + QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, + SingleFilter, + SingleFilterBy, + SingleFilterOperator, Tag, TagSchema, ) @@ -227,6 +233,42 @@ ], } +# TODO: testing query stuff + + +def test_query_runs_schema(): + queryRunsObj = { + "filter": CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.EXPERIMENT_ID, + None, SingleFilterOperator.EQUAL_TO, 7), + None + ] + ), + "sort": 
"sort", + "page": "page" + } + expected_json = { + "filter": { + "operator": "and", + "conditions": [ + { + "by": "experimentId", + "operator": "eq", + "value": '7' + }, + None + ] + }, + "sort": "sort", + "page": "page" + } + data = QueryRunsSchema().dump(queryRunsObj) + assert data == expected_json +# TODO: testing query stuff + def test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) From 27930c18eddcdd0eed111e76ad51a5add87b7a6d Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 11:24:36 +0100 Subject: [PATCH 02/60] Add sort and page fields serialisation and corresponding test --- faculty/clients/experiment.py | 46 ++++++++++++++++++++++++-------- tests/clients/test_experiment.py | 24 +++++++++-------- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index b98d59e0..700c55d9 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -204,6 +204,7 @@ def _serialize(self, value, attr, obj, **kwargs): class SingleFilterSchema(BaseSchema): by = EnumField(SingleFilterBy, by_value=True, required=True) + key = fields.String() operator = EnumField(SingleFilterOperator, by_value=True, required=True) value = SingleFilterValueField(required=True) @@ -214,10 +215,42 @@ class CompoundFilterSchema(BaseSchema): conditions = fields.List(FilterField()) +Sort = namedtuple("Sort",["by", "key", "order"]) + + +class SortBy(Enum): + STARTED_AT = "startedAt" + RUN_NUMBER = "runNumber" + DURATION = "duration" + TAG = "tag" + PARAM = "param" + METRIC = "metric" + + +class SortOrder(Enum): + ASC = "asc" + DESC = "desc" + + +class SortSchema(BaseSchema): + by = EnumField(SortBy, by_value=True, required=True) + key = fields.String() + order = EnumField(SortOrder, by_value=True, required=True) + + +class PageSchema(BaseSchema): + start = fields.Integer(required=True) + limit = fields.Integer(required=True) + + @post_load + def make_page(self, data): + 
return Page(**data) + + class QueryRunsSchema(BaseSchema): filter = FilterField(required=True) - sort = fields.String() - page = fields.String() + sort = fields.List(fields.Nested(SortSchema)) + page = fields.Nested(PageSchema, missing=None) class MetricSchema(BaseSchema): @@ -302,15 +335,6 @@ class ExperimentRunInfoSchema(BaseSchema): ended_at = fields.DateTime(data_key="endedAt", missing=None) -class PageSchema(BaseSchema): - start = fields.Integer(required=True) - limit = fields.Integer(required=True) - - @post_load - def make_page(self, data): - return Page(**data) - - class PaginationSchema(BaseSchema): start = fields.Integer(required=True) size = fields.Integer(required=True) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index d2aad3d7..98b0f4ce 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -55,6 +55,9 @@ SingleFilter, SingleFilterBy, SingleFilterOperator, + Sort, + SortBy, + SortOrder, Tag, TagSchema, ) @@ -234,42 +237,41 @@ } # TODO: testing query stuff - - def test_query_runs_schema(): queryRunsObj = { "filter": CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[ SingleFilter( - SingleFilterBy.EXPERIMENT_ID, - None, SingleFilterOperator.EQUAL_TO, 7), + SingleFilterBy.TAG, + "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), None ] ), - "sort": "sort", - "page": "page" + "sort": [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC)], + "page": PAGE } expected_json = { "filter": { "operator": "and", "conditions": [ { - "by": "experimentId", + "by": "tag", + "key": "tag_key", "operator": "eq", - "value": '7' + "value": "tag_value" }, None ] }, - "sort": "sort", - "page": "page" + "sort": [{"by": "param", "key": "param_key", "order": "asc"}, {"by": "runNumber", "key": None, "order": "desc"}], + "page": PAGE_BODY } data = QueryRunsSchema().dump(queryRunsObj) assert data == expected_json # TODO: testing query stuff - def 
test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) assert data == EXPERIMENT From 0b2033b7c4ce8f4dc93a49db2147cf2747aab701 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 12:41:33 +0100 Subject: [PATCH 03/60] Add QueryRun named tuple --- faculty/clients/experiment.py | 1 + tests/clients/test_experiment.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 700c55d9..2ace5683 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -246,6 +246,7 @@ class PageSchema(BaseSchema): def make_page(self, data): return Page(**data) +QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) class QueryRunsSchema(BaseSchema): filter = FilterField(required=True) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 98b0f4ce..552ce2a5 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -49,6 +49,7 @@ Param, ParamConflict, ParamSchema, + QueryRuns, QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, @@ -236,10 +237,10 @@ ], } -# TODO: testing query stuff + def test_query_runs_schema(): - queryRunsObj = { - "filter": CompoundFilter( + queryRunsObj = QueryRuns( + CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[ SingleFilter( @@ -248,10 +249,10 @@ def test_query_runs_schema(): None ] ), - "sort": [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ + [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC)], - "page": PAGE - } + PAGE + ) expected_json = { "filter": { "operator": "and", @@ -270,7 +271,7 @@ def test_query_runs_schema(): } data = QueryRunsSchema().dump(queryRunsObj) assert data == expected_json -# TODO: testing query stuff + def test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) From 52ad53e6ee0a0ca385ea3667faeadae3b798d67b 
Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 12:02:38 +0100 Subject: [PATCH 04/60] Reorder schemas --- faculty/clients/experiment.py | 184 +++++++++++++++++----------------- 1 file changed, 92 insertions(+), 92 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 2ace5683..abe66357 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -136,86 +136,7 @@ class CompoundFilterOperator(Enum): OR = "or" -class SingleFilterValueField(fields.Field): - """ - Field that serialises/deserialises a run filter. - """ - - def _is_valid_uuid(self, value, obj): - return ( - isinstance(value, uuid.UUID) - and ( - obj.by == SingleFilterBy.PROJECT_ID - or obj.by == SingleFilterBy.RUN_ID - ) - ) - - def _is_valid_experiment_id(self, value, obj): - return ( - isinstance(value, int) - and obj.by == SingleFilterBy.EXPERIMENT_ID - ) - - def _is_directly_stringifiable(self, value, obj): - return ( - self._is_valid_uuid(value, obj) - or self._is_valid_experiment_id(value, obj) - or obj.by == SingleFilterBy.TAG - or obj.by == SingleFilterBy.PARAM - or obj.by == SingleFilterBy.METRIC - ) - - def _deserialize(self, value, attr, obj, **kwargs): - pass - - def _serialize(self, value, attr, obj, **kwargs): - if self._is_directly_stringifiable(value, obj): - return str(value) - elif obj.by == SingleFilterBy.DELETED_AT: - return marshmallow_utils.from_iso_datetime(str(value)) - else: - raise RunQueryFilterValidation( - "Validation error serialising run query filter", - value - ) - - -class FilterField(fields.Field): - """ - Field that serialises/deserialises a run filter. 
- """ - - def _deserialize(self, value, attr, obj, **kwargs): - if value is None: - return None - elif isinstance(value, SingleFilter): - return SingleFilterSchema().load(value) - else: - return CompoundFilterSchema().load(value) - - def _serialize(self, value, attr, obj, **kwargs): - if value is None: - return None - if isinstance(value, SingleFilter): - return SingleFilterSchema().dump(value) - else: - return CompoundFilterSchema().dump(value) - - -class SingleFilterSchema(BaseSchema): - by = EnumField(SingleFilterBy, by_value=True, required=True) - key = fields.String() - operator = EnumField(SingleFilterOperator, by_value=True, required=True) - value = SingleFilterValueField(required=True) - - -class CompoundFilterSchema(BaseSchema): - operator = EnumField( - CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField()) - - -Sort = namedtuple("Sort",["by", "key", "order"]) +Sort = namedtuple("Sort", ["by", "key", "order"]) class SortBy(Enum): @@ -232,12 +153,6 @@ class SortOrder(Enum): DESC = "desc" -class SortSchema(BaseSchema): - by = EnumField(SortBy, by_value=True, required=True) - key = fields.String() - order = EnumField(SortOrder, by_value=True, required=True) - - class PageSchema(BaseSchema): start = fields.Integer(required=True) limit = fields.Integer(required=True) @@ -248,12 +163,6 @@ def make_page(self, data): QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) -class QueryRunsSchema(BaseSchema): - filter = FilterField(required=True) - sort = fields.List(fields.Nested(SortSchema)) - page = fields.Nested(PageSchema, missing=None) - - class MetricSchema(BaseSchema): key = fields.String(required=True) value = fields.Float(required=True) @@ -390,6 +299,97 @@ def make_restore_runs_response(self, data): return RestoreExperimentRunsResponse(**data) +class SingleFilterValueField(fields.Field): + """ + Field that serialises/deserialises a run filter. 
+ """ + + def _is_valid_uuid(self, value, obj): + return ( + isinstance(value, uuid.UUID) + and ( + obj.by == SingleFilterBy.PROJECT_ID + or obj.by == SingleFilterBy.RUN_ID + ) + ) + + def _is_valid_experiment_id(self, value, obj): + return ( + isinstance(value, int) + and obj.by == SingleFilterBy.EXPERIMENT_ID + ) + + def _is_directly_stringifiable(self, value, obj): + return ( + self._is_valid_uuid(value, obj) + or self._is_valid_experiment_id(value, obj) + or obj.by == SingleFilterBy.TAG + or obj.by == SingleFilterBy.PARAM + or obj.by == SingleFilterBy.METRIC + ) + + def _deserialize(self, value, attr, obj, **kwargs): + pass + + def _serialize(self, value, attr, obj, **kwargs): + if self._is_directly_stringifiable(value, obj): + return str(value) + elif obj.by == SingleFilterBy.DELETED_AT: + return marshmallow_utils.from_iso_datetime(str(value)) + else: + raise RunQueryFilterValidation( + "Validation error serialising run query filter", + value + ) + + +class FilterField(fields.Field): + """ + Field that serialises/deserialises a run filter. 
+ """ + + def _deserialize(self, value, attr, obj, **kwargs): + if value is None: + return None + elif isinstance(value, SingleFilter): + return SingleFilterSchema().load(value) + else: + return CompoundFilterSchema().load(value) + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + if isinstance(value, SingleFilter): + return SingleFilterSchema().dump(value) + else: + return CompoundFilterSchema().dump(value) + + +class SingleFilterSchema(BaseSchema): + by = EnumField(SingleFilterBy, by_value=True, required=True) + key = fields.String() + operator = EnumField(SingleFilterOperator, by_value=True, required=True) + value = SingleFilterValueField(required=True) + + +class CompoundFilterSchema(BaseSchema): + operator = EnumField( + CompoundFilterOperator, by_value=True, required=True) + conditions = fields.List(FilterField()) + + +class SortSchema(BaseSchema): + by = EnumField(SortBy, by_value=True, required=True) + key = fields.String() + order = EnumField(SortOrder, by_value=True, required=True) + + +class QueryRunsSchema(BaseSchema): + filter = FilterField(required=True) + sort = fields.List(fields.Nested(SortSchema)) + page = fields.Nested(PageSchema, missing=None) + + class MetricHistorySchema(BaseSchema): history = fields.Nested(MetricSchema, many=True, required=True) From 7e9c221e7e07fadbdff673475b9e4274b46d7afd Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 12:44:05 +0100 Subject: [PATCH 05/60] Rearrange query runs --- faculty/clients/experiment.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index abe66357..08241644 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -153,6 +153,9 @@ class SortOrder(Enum): DESC = "desc" +QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) + + class PageSchema(BaseSchema): start = fields.Integer(required=True) limit = 
fields.Integer(required=True) @@ -161,7 +164,6 @@ class PageSchema(BaseSchema): def make_page(self, data): return Page(**data) -QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) class MetricSchema(BaseSchema): key = fields.String(required=True) From c740eee5b29238d0777504cd5dd252fdc1e9a7b6 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 14:34:54 +0100 Subject: [PATCH 06/60] Parametrise tests and test project id filter --- faculty/clients/experiment.py | 1 + tests/clients/test_experiment.py | 119 ++++++++++++++++++++++--------- 2 files changed, 85 insertions(+), 35 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 08241644..b5e3de70 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -339,6 +339,7 @@ def _serialize(self, value, attr, obj, **kwargs): elif obj.by == SingleFilterBy.DELETED_AT: return marshmallow_utils.from_iso_datetime(str(value)) else: + print(value) raise RunQueryFilterValidation( "Validation error serialising run query filter", value diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 552ce2a5..bef258b2 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -238,41 +238,6 @@ } -def test_query_runs_schema(): - queryRunsObj = QueryRuns( - CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), - None - ] - ), - [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC)], - PAGE - ) - expected_json = { - "filter": { - "operator": "and", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value" - }, - None - ] - }, - "sort": [{"by": "param", "key": "param_key", "order": "asc"}, {"by": "runNumber", "key": None, "order": "desc"}], - "page": PAGE_BODY - } - data = 
QueryRunsSchema().dump(queryRunsObj) - assert data == expected_json - - def test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) assert data == EXPERIMENT @@ -374,6 +339,90 @@ def test_experiment_run_data_schema_multiple(): assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} +# check that all three fields are nullable +# check all types of filter +# check all types of sort + +PROJECT_ID_FILTER = SingleFilter( + by=SingleFilterBy.PROJECT_ID, key="k", operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) +PROJECT_ID_FILTER_BODY = { + "by": "projectId", + "key": "k", + "operator": "eq", + "value": str(PROJECT_ID) +} + +TAG_FILTER = SingleFilter( + by=SingleFilterBy.TAG, key="tag_key", operator=SingleFilterOperator.EQUAL_TO, value="tag_value") +TAG_FILTER_BODY = { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value" +} + +AND_FILTER = CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), + None + ] +) +AND_FILTER_BODY = { + "operator": "and", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value" + }, + None + ] +} + +MULTI_SORT = [ + Sort(SortBy.PARAM, "param_key", SortOrder.ASC), + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC) +] +MULTI_SORT_BODY = [ + {"by": "param", "key": "param_key", "order": "asc"}, + {"by": "runNumber", "key": None, "order": "desc"}] + + +@pytest.mark.parametrize( + "pfilter,pfilter_body", + [ + [None, None], + [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], + [TAG_FILTER, TAG_FILTER_BODY], + [AND_FILTER, AND_FILTER_BODY] + ] +) +@pytest.mark.parametrize( + "psort,psort_body", + [ + [None, None], + [MULTI_SORT, MULTI_SORT_BODY] + ] +) +def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): + queryRunsObj = { + "filter": pfilter, + "sort": psort, + "page": PAGE + } + expected_json = { + "filter": pfilter_body, + "sort": 
psort_body, + "page": PAGE_BODY + } + data = QueryRunsSchema().dump(queryRunsObj) + assert data == expected_json + + @pytest.mark.parametrize("description", [None, "experiment description"]) @pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) def test_experiment_client_create(mocker, description, artifact_location): From b04bc42bee6d090fe8fd314d94e8f82b07253cbd Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 15:07:53 +0100 Subject: [PATCH 07/60] Complete parameterising RunQuery tests for all different filters --- faculty/clients/experiment.py | 2 +- tests/clients/test_experiment.py | 55 ++++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index b5e3de70..81bb02c8 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -337,7 +337,7 @@ def _serialize(self, value, attr, obj, **kwargs): if self._is_directly_stringifiable(value, obj): return str(value) elif obj.by == SingleFilterBy.DELETED_AT: - return marshmallow_utils.from_iso_datetime(str(value)) + return marshmallow_utils.from_iso_datetime(str(value)).isoformat() else: print(value) raise RunQueryFilterValidation( diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index bef258b2..c107f774 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -75,6 +75,7 @@ LAST_UPDATED_AT_STRING = "2018-03-10T11:32:30.172Z" DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC) DELETED_AT_STRING = "2018-03-10T11:37:42.482Z" +DELETED_AT_STRING_PYTHON = "2018-03-10T11:37:42.482000+00:00" EXPERIMENT = Experiment( id=EXPERIMENT_ID, @@ -344,14 +345,41 @@ def test_experiment_run_data_schema_multiple(): # check all types of sort PROJECT_ID_FILTER = SingleFilter( - by=SingleFilterBy.PROJECT_ID, key="k", operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) + by=SingleFilterBy.PROJECT_ID, key=None, 
operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) PROJECT_ID_FILTER_BODY = { "by": "projectId", - "key": "k", + "key": None, "operator": "eq", "value": str(PROJECT_ID) } +EXPERIMENT_ID_FILTER = SingleFilter( + by=SingleFilterBy.EXPERIMENT_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=1) +EXPERIMENT_ID_BODY = { + "by": "experimentId", + "key": None, + "operator": "eq", + "value": "1" +} + +RUN_ID_FILTER = SingleFilter( + by=SingleFilterBy.RUN_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=EXPERIMENT_RUN_ID) +RUN_ID_BODY = { + "by": "runId", + "key": None, + "operator": "eq", + "value": str(EXPERIMENT_RUN_ID) +} + +DELETED_AT_FILTER = SingleFilter( + by=SingleFilterBy.DELETED_AT, key=None, operator=SingleFilterOperator.EQUAL_TO, value=DELETED_AT) +DELETED_AT_BODY = { + "by": "deletedAt", + "key": None, + "operator": "eq", + "value": DELETED_AT_STRING_PYTHON +} + TAG_FILTER = SingleFilter( by=SingleFilterBy.TAG, key="tag_key", operator=SingleFilterOperator.EQUAL_TO, value="tag_value") TAG_FILTER_BODY = { @@ -361,6 +389,24 @@ def test_experiment_run_data_schema_multiple(): "value": "tag_value" } +PARAM_FILTER = SingleFilter( + by=SingleFilterBy.PARAM, key="param_key", operator=SingleFilterOperator.EQUAL_TO, value="param_value") +PARAM_FILTER_BODY = { + "by": "param", + "key": "param_key", + "operator": "eq", + "value": "param_value" +} + +METRIC_FILTER = SingleFilter( + by=SingleFilterBy.METRIC, key="metric_key", operator=SingleFilterOperator.EQUAL_TO, value="metric_value") +METRIC_FILTER_BODY = { + "by": "metric", + "key": "metric_key", + "operator": "eq", + "value": "metric_value" +} + AND_FILTER = CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[ @@ -397,7 +443,12 @@ def test_experiment_run_data_schema_multiple(): [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], + [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], + [RUN_ID_FILTER, RUN_ID_BODY], + [DELETED_AT_FILTER, DELETED_AT_BODY], [TAG_FILTER, 
TAG_FILTER_BODY], + [PARAM_FILTER, PARAM_FILTER_BODY], + [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY] ] ) From e98fdf7db179ae84e8c96ea178e5f516f4de4721 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 15:25:18 +0100 Subject: [PATCH 08/60] Add sort test cases --- tests/clients/test_experiment.py | 48 +++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c107f774..17cf94e6 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -429,6 +429,46 @@ def test_experiment_run_data_schema_multiple(): ] } +OR_FILTER = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), + None + ] +) +OR_FILTER_BODY = { + "operator": "or", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value" + }, + None + ] +} + +RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] +RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] + +STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] +STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] + +DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] +DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] + +PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] +PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] + +TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] +TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] + +METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] +METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] + MULTI_SORT = [ Sort(SortBy.PARAM, "param_key", SortOrder.ASC), Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC) @@ -448,7 
+488,7 @@ def test_experiment_run_data_schema_multiple(): [DELETED_AT_FILTER, DELETED_AT_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], - [METRIC_FILTER, METRIC_FILTER_BODY], + [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY] ] ) @@ -456,6 +496,12 @@ def test_experiment_run_data_schema_multiple(): "psort,psort_body", [ [None, None], + [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], + [STARTED_AT_SORT, STARTED_AT_SORT_BODY], + [DURATION_SORT, DURATION_SORT_BODY], + [PARAM_SORT, PARAM_SORT_BODY], + [TAG_SORT, TAG_SORT_BODY], + [METRIC_SORT, METRIC_SORT_BODY], [MULTI_SORT, MULTI_SORT_BODY] ] ) From 0625d1bda535bdd7c2de14de2ce04fb4dbf87800 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 15:37:30 +0100 Subject: [PATCH 09/60] Fix formatting --- faculty/clients/experiment.py | 18 ++---- tests/clients/test_experiment.py | 95 +++++++++++++++++++------------- 2 files changed, 63 insertions(+), 50 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 81bb02c8..5c2a3c99 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -307,18 +307,14 @@ class SingleFilterValueField(fields.Field): """ def _is_valid_uuid(self, value, obj): - return ( - isinstance(value, uuid.UUID) - and ( - obj.by == SingleFilterBy.PROJECT_ID - or obj.by == SingleFilterBy.RUN_ID - ) + return isinstance(value, uuid.UUID) and ( + obj.by == SingleFilterBy.PROJECT_ID + or obj.by == SingleFilterBy.RUN_ID ) def _is_valid_experiment_id(self, value, obj): return ( - isinstance(value, int) - and obj.by == SingleFilterBy.EXPERIMENT_ID + isinstance(value, int) and obj.by == SingleFilterBy.EXPERIMENT_ID ) def _is_directly_stringifiable(self, value, obj): @@ -341,8 +337,7 @@ def _serialize(self, value, attr, obj, **kwargs): else: print(value) raise RunQueryFilterValidation( - "Validation error serialising run query filter", - value + "Validation error serialising run query filter", value ) @@ 
-376,8 +371,7 @@ class SingleFilterSchema(BaseSchema): class CompoundFilterSchema(BaseSchema): - operator = EnumField( - CompoundFilterOperator, by_value=True, required=True) + operator = EnumField(CompoundFilterOperator, by_value=True, required=True) conditions = fields.List(FilterField()) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 17cf94e6..fecccf6b 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -345,66 +345,82 @@ def test_experiment_run_data_schema_multiple(): # check all types of sort PROJECT_ID_FILTER = SingleFilter( - by=SingleFilterBy.PROJECT_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) + SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID +) PROJECT_ID_FILTER_BODY = { "by": "projectId", "key": None, "operator": "eq", - "value": str(PROJECT_ID) + "value": str(PROJECT_ID), } EXPERIMENT_ID_FILTER = SingleFilter( - by=SingleFilterBy.EXPERIMENT_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=1) + SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 +) EXPERIMENT_ID_BODY = { "by": "experimentId", "key": None, "operator": "eq", - "value": "1" + "value": "1", } RUN_ID_FILTER = SingleFilter( - by=SingleFilterBy.RUN_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=EXPERIMENT_RUN_ID) + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + EXPERIMENT_RUN_ID, +) RUN_ID_BODY = { "by": "runId", "key": None, "operator": "eq", - "value": str(EXPERIMENT_RUN_ID) + "value": str(EXPERIMENT_RUN_ID), } DELETED_AT_FILTER = SingleFilter( - by=SingleFilterBy.DELETED_AT, key=None, operator=SingleFilterOperator.EQUAL_TO, value=DELETED_AT) + SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT +) DELETED_AT_BODY = { "by": "deletedAt", "key": None, "operator": "eq", - "value": DELETED_AT_STRING_PYTHON + "value": DELETED_AT_STRING_PYTHON, } TAG_FILTER = SingleFilter( - 
by=SingleFilterBy.TAG, key="tag_key", operator=SingleFilterOperator.EQUAL_TO, value="tag_value") + SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" +) TAG_FILTER_BODY = { "by": "tag", "key": "tag_key", "operator": "eq", - "value": "tag_value" + "value": "tag_value", } PARAM_FILTER = SingleFilter( - by=SingleFilterBy.PARAM, key="param_key", operator=SingleFilterOperator.EQUAL_TO, value="param_value") + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.EQUAL_TO, + "param_value", +) PARAM_FILTER_BODY = { "by": "param", "key": "param_key", "operator": "eq", - "value": "param_value" + "value": "param_value", } METRIC_FILTER = SingleFilter( - by=SingleFilterBy.METRIC, key="metric_key", operator=SingleFilterOperator.EQUAL_TO, value="metric_value") + SingleFilterBy.METRIC, + "metric_key", + SingleFilterOperator.EQUAL_TO, + "metric_value", +) METRIC_FILTER_BODY = { "by": "metric", "key": "metric_key", "operator": "eq", - "value": "metric_value" + "value": "metric_value", } AND_FILTER = CompoundFilter( @@ -412,9 +428,12 @@ def test_experiment_run_data_schema_multiple(): conditions=[ SingleFilter( SingleFilterBy.TAG, - "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), - None - ] + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], ) AND_FILTER_BODY = { "operator": "and", @@ -423,10 +442,10 @@ def test_experiment_run_data_schema_multiple(): "by": "tag", "key": "tag_key", "operator": "eq", - "value": "tag_value" + "value": "tag_value", }, - None - ] + None, + ], } OR_FILTER = CompoundFilter( @@ -434,9 +453,12 @@ def test_experiment_run_data_schema_multiple(): conditions=[ SingleFilter( SingleFilterBy.TAG, - "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), - None - ] + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], ) OR_FILTER_BODY = { "operator": "or", @@ -445,10 +467,10 @@ def test_experiment_run_data_schema_multiple(): "by": "tag", "key": "tag_key", "operator": "eq", - 
"value": "tag_value" + "value": "tag_value", }, - None - ] + None, + ], } RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] @@ -471,11 +493,12 @@ def test_experiment_run_data_schema_multiple(): MULTI_SORT = [ Sort(SortBy.PARAM, "param_key", SortOrder.ASC), - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC) + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), ] MULTI_SORT_BODY = [ {"by": "param", "key": "param_key", "order": "asc"}, - {"by": "runNumber", "key": None, "order": "desc"}] + {"by": "runNumber", "key": None, "order": "desc"}, +] @pytest.mark.parametrize( @@ -489,8 +512,8 @@ def test_experiment_run_data_schema_multiple(): [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], - [AND_FILTER, AND_FILTER_BODY] - ] + [AND_FILTER, AND_FILTER_BODY], + ], ) @pytest.mark.parametrize( "psort,psort_body", @@ -502,19 +525,15 @@ def test_experiment_run_data_schema_multiple(): [PARAM_SORT, PARAM_SORT_BODY], [TAG_SORT, TAG_SORT_BODY], [METRIC_SORT, METRIC_SORT_BODY], - [MULTI_SORT, MULTI_SORT_BODY] - ] + [MULTI_SORT, MULTI_SORT_BODY], + ], ) def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): - queryRunsObj = { - "filter": pfilter, - "sort": psort, - "page": PAGE - } + queryRunsObj = QueryRuns(pfilter, psort, PAGE) expected_json = { "filter": pfilter_body, "sort": psort_body, - "page": PAGE_BODY + "page": PAGE_BODY, } data = QueryRunsSchema().dump(queryRunsObj) assert data == expected_json From b1988b7b21e9457648156cef1780b9c9d20c1d21 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 17:14:05 +0100 Subject: [PATCH 10/60] Implement validation of RunQuery --- faculty/clients/experiment.py | 12 +- tests/clients/test_experiment.py | 393 ++++++++++++++++--------------- 2 files changed, 209 insertions(+), 196 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 5c2a3c99..43b0c83a 100644 --- a/faculty/clients/experiment.py +++ 
b/faculty/clients/experiment.py @@ -136,8 +136,16 @@ class CompoundFilterOperator(Enum): OR = "or" -Sort = namedtuple("Sort", ["by", "key", "order"]) - +_Sort = namedtuple("_Sort", ["by", "key", "order"]) + +class Sort(_Sort): + def __new__(self, by, key, order): + if by in [SortBy.STARTED_AT, SortBy.RUN_NUMBER, SortBy.DURATION] and key is not None: + raise ValueError("key must be none for type {}".format(by)) + elif by in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] and key is None: + raise ValueError("key must not be none for type {}".format(by)) + self = super(Sort, self).__init__(by, key, order) + return self class SortBy(Enum): STARTED_AT = "startedAt" diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index fecccf6b..b37b486e 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -344,200 +344,205 @@ def test_experiment_run_data_schema_multiple(): # check all types of filter # check all types of sort -PROJECT_ID_FILTER = SingleFilter( - SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID -) -PROJECT_ID_FILTER_BODY = { - "by": "projectId", - "key": None, - "operator": "eq", - "value": str(PROJECT_ID), -} - -EXPERIMENT_ID_FILTER = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 -) -EXPERIMENT_ID_BODY = { - "by": "experimentId", - "key": None, - "operator": "eq", - "value": "1", -} - -RUN_ID_FILTER = SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - EXPERIMENT_RUN_ID, -) -RUN_ID_BODY = { - "by": "runId", - "key": None, - "operator": "eq", - "value": str(EXPERIMENT_RUN_ID), -} - -DELETED_AT_FILTER = SingleFilter( - SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT -) -DELETED_AT_BODY = { - "by": "deletedAt", - "key": None, - "operator": "eq", - "value": DELETED_AT_STRING_PYTHON, -} - -TAG_FILTER = SingleFilter( - SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" -) 
-TAG_FILTER_BODY = { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", -} - -PARAM_FILTER = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.EQUAL_TO, - "param_value", -) -PARAM_FILTER_BODY = { - "by": "param", - "key": "param_key", - "operator": "eq", - "value": "param_value", -} - -METRIC_FILTER = SingleFilter( - SingleFilterBy.METRIC, - "metric_key", - SingleFilterOperator.EQUAL_TO, - "metric_value", -) -METRIC_FILTER_BODY = { - "by": "metric", - "key": "metric_key", - "operator": "eq", - "value": "metric_value", -} - -AND_FILTER = CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), - None, - ], -) -AND_FILTER_BODY = { - "operator": "and", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, - None, - ], -} - -OR_FILTER = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), - None, - ], -) -OR_FILTER_BODY = { - "operator": "or", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, - None, - ], -} - -RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] -RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] - -STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] -STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] - -DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] -DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] - -PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] -PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] - -TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] -TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", 
"order": "desc"}] - -METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] -METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] - -MULTI_SORT = [ - Sort(SortBy.PARAM, "param_key", SortOrder.ASC), - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), -] -MULTI_SORT_BODY = [ - {"by": "param", "key": "param_key", "order": "asc"}, - {"by": "runNumber", "key": None, "order": "desc"}, -] - - -@pytest.mark.parametrize( - "pfilter,pfilter_body", - [ - [None, None], - [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], - [RUN_ID_FILTER, RUN_ID_BODY], - [DELETED_AT_FILTER, DELETED_AT_BODY], - [TAG_FILTER, TAG_FILTER_BODY], - [PARAM_FILTER, PARAM_FILTER_BODY], - [METRIC_FILTER, METRIC_FILTER_BODY], - [AND_FILTER, AND_FILTER_BODY], - ], -) -@pytest.mark.parametrize( - "psort,psort_body", - [ - [None, None], - [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], - [STARTED_AT_SORT, STARTED_AT_SORT_BODY], - [DURATION_SORT, DURATION_SORT_BODY], - [PARAM_SORT, PARAM_SORT_BODY], - [TAG_SORT, TAG_SORT_BODY], - [METRIC_SORT, METRIC_SORT_BODY], - [MULTI_SORT, MULTI_SORT_BODY], - ], -) -def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): - queryRunsObj = QueryRuns(pfilter, psort, PAGE) - expected_json = { - "filter": pfilter_body, - "sort": psort_body, - "page": PAGE_BODY, - } - data = QueryRunsSchema().dump(queryRunsObj) - assert data == expected_json - +# PROJECT_ID_FILTER = SingleFilter( +# SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID +# ) +# PROJECT_ID_FILTER_BODY = { +# "by": "projectId", +# "key": None, +# "operator": "eq", +# "value": str(PROJECT_ID), +# } + +# EXPERIMENT_ID_FILTER = SingleFilter( +# SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 +# ) +# EXPERIMENT_ID_BODY = { +# "by": "experimentId", +# "key": None, +# "operator": "eq", +# "value": "1", +# } + +# RUN_ID_FILTER = SingleFilter( +# SingleFilterBy.RUN_ID, +# None, +# 
SingleFilterOperator.EQUAL_TO, +# EXPERIMENT_RUN_ID, +# ) +# RUN_ID_BODY = { +# "by": "runId", +# "key": None, +# "operator": "eq", +# "value": str(EXPERIMENT_RUN_ID), +# } + +# DELETED_AT_FILTER = SingleFilter( +# SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT +# ) +# DELETED_AT_BODY = { +# "by": "deletedAt", +# "key": None, +# "operator": "eq", +# "value": DELETED_AT_STRING_PYTHON, +# } + +# TAG_FILTER = SingleFilter( +# SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" +# ) +# TAG_FILTER_BODY = { +# "by": "tag", +# "key": "tag_key", +# "operator": "eq", +# "value": "tag_value", +# } + +# PARAM_FILTER = SingleFilter( +# SingleFilterBy.PARAM, +# "param_key", +# SingleFilterOperator.EQUAL_TO, +# "param_value", +# ) +# PARAM_FILTER_BODY = { +# "by": "param", +# "key": "param_key", +# "operator": "eq", +# "value": "param_value", +# } + +# METRIC_FILTER = SingleFilter( +# SingleFilterBy.METRIC, +# "metric_key", +# SingleFilterOperator.EQUAL_TO, +# "metric_value", +# ) +# METRIC_FILTER_BODY = { +# "by": "metric", +# "key": "metric_key", +# "operator": "eq", +# "value": "metric_value", +# } + +# AND_FILTER = CompoundFilter( +# operator=CompoundFilterOperator.AND, +# conditions=[ +# SingleFilter( +# SingleFilterBy.TAG, +# "tag_key", +# SingleFilterOperator.EQUAL_TO, +# "tag_value", +# ), +# None, +# ], +# ) +# AND_FILTER_BODY = { +# "operator": "and", +# "conditions": [ +# { +# "by": "tag", +# "key": "tag_key", +# "operator": "eq", +# "value": "tag_value", +# }, +# None, +# ], +# } + +# OR_FILTER = CompoundFilter( +# operator=CompoundFilterOperator.OR, +# conditions=[ +# SingleFilter( +# SingleFilterBy.TAG, +# "tag_key", +# SingleFilterOperator.EQUAL_TO, +# "tag_value", +# ), +# None, +# ], +# ) +# OR_FILTER_BODY = { +# "operator": "or", +# "conditions": [ +# { +# "by": "tag", +# "key": "tag_key", +# "operator": "eq", +# "value": "tag_value", +# }, +# None, +# ], +# } + +# RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, 
None, SortOrder.ASC)] +# RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] + +# STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] +# STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] + +# DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] +# DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] + +# PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] +# PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] + +# TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] +# TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] + +# METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] +# METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] + +# MULTI_SORT = [ +# Sort(SortBy.PARAM, "param_key", SortOrder.ASC), +# Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), +# ] +# MULTI_SORT_BODY = [ +# {"by": "param", "key": "param_key", "order": "asc"}, +# {"by": "runNumber", "key": None, "order": "desc"}, +# ] + + +# @pytest.mark.parametrize( +# "pfilter,pfilter_body", +# [ +# [None, None], +# [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], +# [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], +# [RUN_ID_FILTER, RUN_ID_BODY], +# [DELETED_AT_FILTER, DELETED_AT_BODY], +# [TAG_FILTER, TAG_FILTER_BODY], +# [PARAM_FILTER, PARAM_FILTER_BODY], +# [METRIC_FILTER, METRIC_FILTER_BODY], +# [AND_FILTER, AND_FILTER_BODY], +# ], +# ) +# @pytest.mark.parametrize( +# "psort,psort_body", +# [ +# [None, None], +# [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], +# [STARTED_AT_SORT, STARTED_AT_SORT_BODY], +# [DURATION_SORT, DURATION_SORT_BODY], +# [PARAM_SORT, PARAM_SORT_BODY], +# [TAG_SORT, TAG_SORT_BODY], +# [METRIC_SORT, METRIC_SORT_BODY], +# [MULTI_SORT, MULTI_SORT_BODY], +# ], +# ) +# def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): +# queryRunsObj = QueryRuns(pfilter, psort, PAGE) +# expected_json = { +# "filter": 
pfilter_body, +# "sort": psort_body, +# "page": PAGE_BODY, +# } +# data = QueryRunsSchema().dump(queryRunsObj) +# assert data == expected_json + +def test_sort_validation(mocker): + with pytest.raises( + ValueError, match="key must be none for type {}".format(SortBy.RUN_NUMBER) + ): + Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) @pytest.mark.parametrize("description", [None, "experiment description"]) @pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) From 761e970d9acd32b982c9c3bafa875d11315a2629 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 17:35:50 +0100 Subject: [PATCH 11/60] Fix validation in Sort constructor --- faculty/clients/experiment.py | 7 +- tests/clients/test_experiment.py | 396 +++++++++++++++---------------- 2 files changed, 202 insertions(+), 201 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 43b0c83a..816b3925 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -138,15 +138,17 @@ class CompoundFilterOperator(Enum): _Sort = namedtuple("_Sort", ["by", "key", "order"]) + class Sort(_Sort): - def __new__(self, by, key, order): + def __new__(cls, by, key, order): if by in [SortBy.STARTED_AT, SortBy.RUN_NUMBER, SortBy.DURATION] and key is not None: raise ValueError("key must be none for type {}".format(by)) elif by in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] and key is None: raise ValueError("key must not be none for type {}".format(by)) - self = super(Sort, self).__init__(by, key, order) + self = super(Sort, cls).__new__(cls, by, key, order) return self + class SortBy(Enum): STARTED_AT = "startedAt" RUN_NUMBER = "runNumber" @@ -343,7 +345,6 @@ def _serialize(self, value, attr, obj, **kwargs): elif obj.by == SingleFilterBy.DELETED_AT: return marshmallow_utils.from_iso_datetime(str(value)).isoformat() else: - print(value) raise RunQueryFilterValidation( "Validation error serialising run query filter", value ) diff --git 
a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index b37b486e..4192c6f4 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -340,210 +340,210 @@ def test_experiment_run_data_schema_multiple(): assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} -# check that all three fields are nullable -# check all types of filter -# check all types of sort - -# PROJECT_ID_FILTER = SingleFilter( -# SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID -# ) -# PROJECT_ID_FILTER_BODY = { -# "by": "projectId", -# "key": None, -# "operator": "eq", -# "value": str(PROJECT_ID), -# } - -# EXPERIMENT_ID_FILTER = SingleFilter( -# SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 -# ) -# EXPERIMENT_ID_BODY = { -# "by": "experimentId", -# "key": None, -# "operator": "eq", -# "value": "1", -# } - -# RUN_ID_FILTER = SingleFilter( -# SingleFilterBy.RUN_ID, -# None, -# SingleFilterOperator.EQUAL_TO, -# EXPERIMENT_RUN_ID, -# ) -# RUN_ID_BODY = { -# "by": "runId", -# "key": None, -# "operator": "eq", -# "value": str(EXPERIMENT_RUN_ID), -# } - -# DELETED_AT_FILTER = SingleFilter( -# SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT -# ) -# DELETED_AT_BODY = { -# "by": "deletedAt", -# "key": None, -# "operator": "eq", -# "value": DELETED_AT_STRING_PYTHON, -# } - -# TAG_FILTER = SingleFilter( -# SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" -# ) -# TAG_FILTER_BODY = { -# "by": "tag", -# "key": "tag_key", -# "operator": "eq", -# "value": "tag_value", -# } - -# PARAM_FILTER = SingleFilter( -# SingleFilterBy.PARAM, -# "param_key", -# SingleFilterOperator.EQUAL_TO, -# "param_value", -# ) -# PARAM_FILTER_BODY = { -# "by": "param", -# "key": "param_key", -# "operator": "eq", -# "value": "param_value", -# } - -# METRIC_FILTER = SingleFilter( -# SingleFilterBy.METRIC, -# "metric_key", -# SingleFilterOperator.EQUAL_TO, -# "metric_value", -# ) -# 
METRIC_FILTER_BODY = { -# "by": "metric", -# "key": "metric_key", -# "operator": "eq", -# "value": "metric_value", -# } - -# AND_FILTER = CompoundFilter( -# operator=CompoundFilterOperator.AND, -# conditions=[ -# SingleFilter( -# SingleFilterBy.TAG, -# "tag_key", -# SingleFilterOperator.EQUAL_TO, -# "tag_value", -# ), -# None, -# ], -# ) -# AND_FILTER_BODY = { -# "operator": "and", -# "conditions": [ -# { -# "by": "tag", -# "key": "tag_key", -# "operator": "eq", -# "value": "tag_value", -# }, -# None, -# ], -# } - -# OR_FILTER = CompoundFilter( -# operator=CompoundFilterOperator.OR, -# conditions=[ -# SingleFilter( -# SingleFilterBy.TAG, -# "tag_key", -# SingleFilterOperator.EQUAL_TO, -# "tag_value", -# ), -# None, -# ], -# ) -# OR_FILTER_BODY = { -# "operator": "or", -# "conditions": [ -# { -# "by": "tag", -# "key": "tag_key", -# "operator": "eq", -# "value": "tag_value", -# }, -# None, -# ], -# } - -# RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] -# RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] - -# STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] -# STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] - -# DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] -# DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] - -# PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] -# PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] - -# TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] -# TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] - -# METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] -# METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] - -# MULTI_SORT = [ -# Sort(SortBy.PARAM, "param_key", SortOrder.ASC), -# Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), -# ] -# MULTI_SORT_BODY = [ -# {"by": "param", "key": "param_key", "order": "asc"}, -# {"by": 
"runNumber", "key": None, "order": "desc"}, -# ] - - -# @pytest.mark.parametrize( -# "pfilter,pfilter_body", -# [ -# [None, None], -# [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], -# [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], -# [RUN_ID_FILTER, RUN_ID_BODY], -# [DELETED_AT_FILTER, DELETED_AT_BODY], -# [TAG_FILTER, TAG_FILTER_BODY], -# [PARAM_FILTER, PARAM_FILTER_BODY], -# [METRIC_FILTER, METRIC_FILTER_BODY], -# [AND_FILTER, AND_FILTER_BODY], -# ], -# ) -# @pytest.mark.parametrize( -# "psort,psort_body", -# [ -# [None, None], -# [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], -# [STARTED_AT_SORT, STARTED_AT_SORT_BODY], -# [DURATION_SORT, DURATION_SORT_BODY], -# [PARAM_SORT, PARAM_SORT_BODY], -# [TAG_SORT, TAG_SORT_BODY], -# [METRIC_SORT, METRIC_SORT_BODY], -# [MULTI_SORT, MULTI_SORT_BODY], -# ], -# ) -# def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): -# queryRunsObj = QueryRuns(pfilter, psort, PAGE) -# expected_json = { -# "filter": pfilter_body, -# "sort": psort_body, -# "page": PAGE_BODY, -# } -# data = QueryRunsSchema().dump(queryRunsObj) -# assert data == expected_json +PROJECT_ID_FILTER = SingleFilter( + SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID +) +PROJECT_ID_FILTER_BODY = { + "by": "projectId", + "key": None, + "operator": "eq", + "value": str(PROJECT_ID), +} + +EXPERIMENT_ID_FILTER = SingleFilter( + SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 +) +EXPERIMENT_ID_BODY = { + "by": "experimentId", + "key": None, + "operator": "eq", + "value": "1", +} + +RUN_ID_FILTER = SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + EXPERIMENT_RUN_ID, +) +RUN_ID_BODY = { + "by": "runId", + "key": None, + "operator": "eq", + "value": str(EXPERIMENT_RUN_ID), +} + +DELETED_AT_FILTER = SingleFilter( + SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT +) +DELETED_AT_BODY = { + "by": "deletedAt", + "key": None, + "operator": "eq", + "value": 
DELETED_AT_STRING_PYTHON, +} + +TAG_FILTER = SingleFilter( + SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" +) +TAG_FILTER_BODY = { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", +} + +PARAM_FILTER = SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.EQUAL_TO, + "param_value", +) +PARAM_FILTER_BODY = { + "by": "param", + "key": "param_key", + "operator": "eq", + "value": "param_value", +} + +METRIC_FILTER = SingleFilter( + SingleFilterBy.METRIC, + "metric_key", + SingleFilterOperator.EQUAL_TO, + "metric_value", +) +METRIC_FILTER_BODY = { + "by": "metric", + "key": "metric_key", + "operator": "eq", + "value": "metric_value", +} + +AND_FILTER = CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], +) +AND_FILTER_BODY = { + "operator": "and", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", + }, + None, + ], +} + +OR_FILTER = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], +) +OR_FILTER_BODY = { + "operator": "or", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", + }, + None, + ], +} + +RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] +RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] + +STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] +STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] + +DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] +DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] + +PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] +PARAM_SORT_BODY = [{"by": "param", "key": 
"param_key", "order": "desc"}] + +TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] +TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] + +METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] +METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] + +MULTI_SORT = [ + Sort(SortBy.PARAM, "param_key", SortOrder.ASC), + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), +] +MULTI_SORT_BODY = [ + {"by": "param", "key": "param_key", "order": "asc"}, + {"by": "runNumber", "key": None, "order": "desc"}, +] + + +@pytest.mark.parametrize( + "pfilter,pfilter_body", + [ + [None, None], + [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], + [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], + [RUN_ID_FILTER, RUN_ID_BODY], + [DELETED_AT_FILTER, DELETED_AT_BODY], + [TAG_FILTER, TAG_FILTER_BODY], + [PARAM_FILTER, PARAM_FILTER_BODY], + [METRIC_FILTER, METRIC_FILTER_BODY], + [AND_FILTER, AND_FILTER_BODY], + ], +) +@pytest.mark.parametrize( + "psort,psort_body", + [ + [None, None], + [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], + [STARTED_AT_SORT, STARTED_AT_SORT_BODY], + [DURATION_SORT, DURATION_SORT_BODY], + [PARAM_SORT, PARAM_SORT_BODY], + [TAG_SORT, TAG_SORT_BODY], + [METRIC_SORT, METRIC_SORT_BODY], + [MULTI_SORT, MULTI_SORT_BODY], + ], +) +def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): + queryRunsObj = QueryRuns(pfilter, psort, PAGE) + s = Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC) + expected_json = { + "filter": pfilter_body, + "sort": psort_body, + "page": PAGE_BODY, + } + data = QueryRunsSchema().dump(queryRunsObj) + assert data == expected_json + def test_sort_validation(mocker): with pytest.raises( - ValueError, match="key must be none for type {}".format(SortBy.RUN_NUMBER) + ValueError, + match="key must be none for type {}".format(SortBy.RUN_NUMBER) ): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) + @pytest.mark.parametrize("description", [None, "experiment description"]) 
@pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) def test_experiment_client_create(mocker, description, artifact_location): From 5ba48da697f9a1981e0f9615c53200c9267f5920 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 18:11:53 +0100 Subject: [PATCH 12/60] Add validation for filters --- faculty/clients/experiment.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 816b3925..9c77e8a7 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -106,7 +106,16 @@ class ExperimentRunStatus(Enum): "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] ) -SingleFilter = namedtuple("SingleFilter", ["by", "key", "operator", "value"]) +_SingleFilter = namedtuple("_SingleFilter", ["by", "key", "operator", "value"]) + +class SingleFilter(_SingleFilter): + def __new__(cls, by, key, operator, value): + if by.has_key() and key is None: + raise ValueError("key must not be none for a {} filter".format(by)) + elif not by.has_key() and key is not None: + raise ValueError("key must be none for a {} filter".format(by)) + return super(SingleFilter, cls).__new__(cls, by, key, operator, value) + CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) @@ -130,6 +139,9 @@ class SingleFilterBy(Enum): PARAM = "param" METRIC = "metric" + def has_key(self): + return self in [SingleFilterBy.TAG, SingleFilterBy.PARAM, SingleFilterBy.METRIC] + class CompoundFilterOperator(Enum): AND = "and" @@ -141,12 +153,11 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): - if by in [SortBy.STARTED_AT, SortBy.RUN_NUMBER, SortBy.DURATION] and key is not None: + if by.has_key() and key is None: raise ValueError("key must be none for type {}".format(by)) - elif by in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] and key is None: - raise ValueError("key must not be 
none for type {}".format(by)) - self = super(Sort, cls).__new__(cls, by, key, order) - return self + elif not by.has_key() and key is not None: + raise ValueError("key must be none for type {}".format(by)) + return super(Sort, cls).__new__(cls, by, key, order) class SortBy(Enum): @@ -157,6 +168,9 @@ class SortBy(Enum): PARAM = "param" METRIC = "metric" + def has_key(self): + return self in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] + class SortOrder(Enum): ASC = "asc" From 38993eebf37f47790c432288e36e4e1157b2d507 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 09:06:02 +0100 Subject: [PATCH 13/60] Reformat code and address warning --- faculty/clients/experiment.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 9c77e8a7..b09c0f27 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -108,11 +108,12 @@ class ExperimentRunStatus(Enum): _SingleFilter = namedtuple("_SingleFilter", ["by", "key", "operator", "value"]) + class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): - if by.has_key() and key is None: + if by.filter_has_key() and key is None: raise ValueError("key must not be none for a {} filter".format(by)) - elif not by.has_key() and key is not None: + elif not by.filter_has_key() and key is not None: raise ValueError("key must be none for a {} filter".format(by)) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -139,8 +140,12 @@ class SingleFilterBy(Enum): PARAM = "param" METRIC = "metric" - def has_key(self): - return self in [SingleFilterBy.TAG, SingleFilterBy.PARAM, SingleFilterBy.METRIC] + def filter_has_key(self): + return self in { + SingleFilterBy.TAG, + SingleFilterBy.PARAM, + SingleFilterBy.METRIC + } class CompoundFilterOperator(Enum): @@ -153,9 +158,9 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): - if 
by.has_key() and key is None: + if by.filter_has_key() and key is None: raise ValueError("key must be none for type {}".format(by)) - elif not by.has_key() and key is not None: + elif not by.filter_has_key() and key is not None: raise ValueError("key must be none for type {}".format(by)) return super(Sort, cls).__new__(cls, by, key, order) @@ -168,8 +173,8 @@ class SortBy(Enum): PARAM = "param" METRIC = "metric" - def has_key(self): - return self in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] + def filter_has_key(self): + return self in {SortBy.TAG, SortBy.PARAM, SortBy.METRIC} class SortOrder(Enum): From 282dd2605028311ffb79b454001afe542e0fdad3 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 09:08:03 +0100 Subject: [PATCH 14/60] Remove unused variable --- faculty/clients/experiment.py | 2 +- tests/clients/test_experiment.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index b09c0f27..eef581c8 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -144,7 +144,7 @@ def filter_has_key(self): return self in { SingleFilterBy.TAG, SingleFilterBy.PARAM, - SingleFilterBy.METRIC + SingleFilterBy.METRIC, } diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 4192c6f4..5e6275f2 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -526,7 +526,6 @@ def test_experiment_run_data_schema_multiple(): ) def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): queryRunsObj = QueryRuns(pfilter, psort, PAGE) - s = Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC) expected_json = { "filter": pfilter_body, "sort": psort_body, @@ -539,7 +538,7 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): def test_sort_validation(mocker): with pytest.raises( ValueError, - match="key must be none for type {}".format(SortBy.RUN_NUMBER) + match="key must be 
none for type {}".format(SortBy.RUN_NUMBER), ): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) From 7c414f91498e25fa1d398d47905605b932f59bda Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 09:10:32 +0100 Subject: [PATCH 15/60] Rename helper methods --- faculty/clients/experiment.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index eef581c8..a5f51bc3 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -111,9 +111,9 @@ class ExperimentRunStatus(Enum): class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): - if by.filter_has_key() and key is None: + if by.needs_key() and key is None: raise ValueError("key must not be none for a {} filter".format(by)) - elif not by.filter_has_key() and key is not None: + elif not by.needs_key() and key is not None: raise ValueError("key must be none for a {} filter".format(by)) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -140,7 +140,7 @@ class SingleFilterBy(Enum): PARAM = "param" METRIC = "metric" - def filter_has_key(self): + def needs_key(self): return self in { SingleFilterBy.TAG, SingleFilterBy.PARAM, @@ -158,9 +158,9 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): - if by.filter_has_key() and key is None: + if by.needs_key() and key is None: raise ValueError("key must be none for type {}".format(by)) - elif not by.filter_has_key() and key is not None: + elif not by.needs_key() and key is not None: raise ValueError("key must be none for type {}".format(by)) return super(Sort, cls).__new__(cls, by, key, order) @@ -173,7 +173,7 @@ class SortBy(Enum): PARAM = "param" METRIC = "metric" - def filter_has_key(self): + def needs_key(self): return self in {SortBy.TAG, SortBy.PARAM, SortBy.METRIC} From 491cd1743477cd9b1604e54b748d2aed40c00e58 Mon Sep 17 00:00:00 2001 From: Elias Benussi 
Date: Tue, 14 May 2019 09:16:48 +0100 Subject: [PATCH 16/60] Use in instead of single element equality --- faculty/clients/experiment.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a5f51bc3..1c3c42be 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -337,8 +337,7 @@ class SingleFilterValueField(fields.Field): def _is_valid_uuid(self, value, obj): return isinstance(value, uuid.UUID) and ( - obj.by == SingleFilterBy.PROJECT_ID - or obj.by == SingleFilterBy.RUN_ID + obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID} ) def _is_valid_experiment_id(self, value, obj): @@ -350,9 +349,12 @@ def _is_directly_stringifiable(self, value, obj): return ( self._is_valid_uuid(value, obj) or self._is_valid_experiment_id(value, obj) - or obj.by == SingleFilterBy.TAG - or obj.by == SingleFilterBy.PARAM - or obj.by == SingleFilterBy.METRIC + or obj.by + in { + SingleFilterBy.TAG, + SingleFilterBy.PARAM, + SingleFilterBy.METRIC, + } ) def _deserialize(self, value, attr, obj, **kwargs): From 8c371186b845f8dbabcf4c859e5c9aeffd046f34 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 10:08:44 +0100 Subject: [PATCH 17/60] Add tests to filter and sort validation --- faculty/clients/experiment.py | 12 +++++++---- tests/clients/test_experiment.py | 34 +++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1c3c42be..61199285 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -112,9 +112,11 @@ class ExperimentRunStatus(Enum): class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): if by.needs_key() and key is None: - raise ValueError("key must not be none for a {} filter".format(by)) + raise ValueError( + "key must not be none for filter type {}".format(by) + ) elif not 
by.needs_key() and key is not None: - raise ValueError("key must be none for a {} filter".format(by)) + raise ValueError("key must be none for filter type {}".format(by)) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -159,9 +161,11 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): if by.needs_key() and key is None: - raise ValueError("key must be none for type {}".format(by)) + raise ValueError( + "key must not be none for sort type {}".format(by) + ) elif not by.needs_key() and key is not None: - raise ValueError("key must be none for type {}".format(by)) + raise ValueError("key must be none for sort type {}".format(by)) return super(Sort, cls).__new__(cls, by, key, order) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 5e6275f2..c300cd62 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -535,12 +535,44 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): assert data == expected_json +def test_single_filter_validation(mocker): + with pytest.raises( + ValueError, + match="key must be none for filter type {}".format( + SingleFilterBy.PROJECT_ID + ), + ): + SingleFilter( + SingleFilterBy.PROJECT_ID, + "invalid_key", + SingleFilterOperator.EQUAL_TO, + PROJECT_ID, + ) + with pytest.raises( + ValueError, + match="key must not be none for filter type {}".format( + SingleFilterBy.TAG + ), + ): + SingleFilter( + SingleFilterBy.TAG, + None, + SingleFilterOperator.EQUAL_TO, + "tag_value", + ) + + def test_sort_validation(mocker): with pytest.raises( ValueError, - match="key must be none for type {}".format(SortBy.RUN_NUMBER), + match="key must be none for sort type {}".format(SortBy.RUN_NUMBER), ): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) + with pytest.raises( + ValueError, + match="key must not be none for sort type {}".format(SortBy.TAG), + ): + Sort(SortBy.TAG, None, SortOrder.ASC) 
@pytest.mark.parametrize("description", [None, "experiment description"]) From 9f794a9a33c9de89ff980db6fc1cb843b9c43d5c Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Tue, 14 May 2019 11:54:27 +0100 Subject: [PATCH 18/60] Add query_runs and test --- faculty/clients/experiment.py | 46 ++++++++++++++++++++++++++++++++ tests/clients/test_experiment.py | 46 +++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 61199285..5ccb83c0 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -708,6 +708,52 @@ def list_runs( endpoint, ListExperimentRunsResponseSchema(), params=query_params ) + def query_runs( + self, + project_id, + filter=None, + sort=None, + start=None, + limit=None, + ): + """Query experiment runs. + + This method returns pages of runs. If less than the full number of runs + for the job is returned, the ``next`` page of the returned response + object will not be ``None``: + + >>> response = client.query_runs(project_id) + >>> response.pagination.next + Page(start=10, limit=10) + + Get all experiment runs by making successive calls to ``query_runs``, + passing the ``start`` and ``limit`` of the ``next`` page each time + until ``next`` is returned as ``None``. + + Parameters + ---------- + project_id : uuid.UUID + filter: either a SingleFilter object or a CompoundFilter object, optional + To filter runs of experiments with the given filter. By default, runs + from all experiments are returned. + sort: List[Sort], optional + Runs are ordered using sorting elements lexicographically. By default, + experiment runs are sorted by their startedAt value. + start : int, optional + The (zero-indexed) starting point of runs to retrieve. + limit : int, optional + The maximum number of runs to retrieve. 
+ + Returns + ------- + ListExperimentRunsResponse + """ + endpoint = "/project/{}/run/query".format(project_id) + # runs_query = QueryRuns(filter, sort, Page(start, limit)) + payload = QueryRunsSchema().dump(QueryRuns(filter, sort, Page(start, limit))) + return self._post(endpoint, ListExperimentRunsResponseSchema(), json=payload) + + def log_run_data( self, project_id, run_id, metrics=None, params=None, tags=None ): diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c300cd62..da6a2c5b 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -536,7 +536,7 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): def test_single_filter_validation(mocker): - with pytest.raises( + with pytest.raises( ValueError, match="key must be none for filter type {}".format( SingleFilterBy.PROJECT_ID @@ -908,6 +908,50 @@ def test_experiment_client_list_runs_page(mocker): ) +def test_experiment_client_query_runs(mocker): + mocker.patch.object( + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ) + response_schema_mock = mocker.patch( + "faculty.clients.experiment.ListExperimentRunsResponseSchema" + ) + request_schema_mock = mocker.patch( + "faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump + + test_filter = SingleFilter( + SingleFilterBy.EXPERIMENT_ID, + None, + SingleFilterOperator.EQUAL_TO, + "2" + ) + test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] + + client = ExperimentClient(mocker.Mock()) + list_result = client.query_runs( + PROJECT_ID, + filter=test_filter, + sort=test_sort, + start=20, + limit=10 + ) + + assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE + + request_schema_mock.assert_called_once_with() + dump_mock.assert_called_once_with( + QueryRuns(test_filter, test_sort, Page(20, 10)) + ) + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + 
"/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, + ) + + + def test_log_run_data(mocker): mocker.patch.object(ExperimentClient, "_patch_raw") run_data_schema_mock = mocker.patch( From dae0209ce1aef6f118d0184c334a439f21092b71 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 11:59:36 +0100 Subject: [PATCH 19/60] Reformat code correctly --- faculty/clients/experiment.py | 26 ++++++++++++-------------- tests/clients/test_experiment.py | 16 ++++------------ 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 5ccb83c0..2aa64d21 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -709,12 +709,7 @@ def list_runs( ) def query_runs( - self, - project_id, - filter=None, - sort=None, - start=None, - limit=None, + self, project_id, filter=None, sort=None, start=None, limit=None ): """Query experiment runs. @@ -733,12 +728,12 @@ def query_runs( Parameters ---------- project_id : uuid.UUID - filter: either a SingleFilter object or a CompoundFilter object, optional - To filter runs of experiments with the given filter. By default, runs - from all experiments are returned. + filter: either SingleFilter object or CompoundFilter object, optional + To filter runs of experiments with the given filter. By default, + runs from all experiments are returned. sort: List[Sort], optional - Runs are ordered using sorting elements lexicographically. By default, - experiment runs are sorted by their startedAt value. + Runs are ordered using sorting elements lexicographically. By + default, experiment runs are sorted by their startedAt value. start : int, optional The (zero-indexed) starting point of runs to retrieve. 
limit : int, optional @@ -750,9 +745,12 @@ def query_runs( """ endpoint = "/project/{}/run/query".format(project_id) # runs_query = QueryRuns(filter, sort, Page(start, limit)) - payload = QueryRunsSchema().dump(QueryRuns(filter, sort, Page(start, limit))) - return self._post(endpoint, ListExperimentRunsResponseSchema(), json=payload) - + payload = QueryRunsSchema().dump( + QueryRuns(filter, sort, Page(start, limit)) + ) + return self._post( + endpoint, ListExperimentRunsResponseSchema(), json=payload + ) def log_run_data( self, project_id, run_id, metrics=None, params=None, tags=None diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index da6a2c5b..569f2cec 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -536,7 +536,7 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): def test_single_filter_validation(mocker): - with pytest.raises( + with pytest.raises( ValueError, match="key must be none for filter type {}".format( SingleFilterBy.PROJECT_ID @@ -921,27 +921,20 @@ def test_experiment_client_query_runs(mocker): dump_mock = request_schema_mock.return_value.dump test_filter = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, - None, - SingleFilterOperator.EQUAL_TO, - "2" + SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2" ) test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] client = ExperimentClient(mocker.Mock()) list_result = client.query_runs( - PROJECT_ID, - filter=test_filter, - sort=test_sort, - start=20, - limit=10 + PROJECT_ID, filter=test_filter, sort=test_sort, start=20, limit=10 ) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE request_schema_mock.assert_called_once_with() dump_mock.assert_called_once_with( - QueryRuns(test_filter, test_sort, Page(20, 10)) + QueryRuns(test_filter, test_sort, Page(20, 10)) ) response_schema_mock.assert_called_once_with() ExperimentClient._post.assert_called_once_with( @@ -951,7 +944,6 @@ def 
test_experiment_client_query_runs(mocker): ) - def test_log_run_data(mocker): mocker.patch.object(ExperimentClient, "_patch_raw") run_data_schema_mock = mocker.patch( From 5767a87775cd3755c6007f7883d116117cda736d Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Tue, 14 May 2019 15:04:32 +0100 Subject: [PATCH 20/60] Modify deserialise function in SingleFilterValueField --- faculty/clients/experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 2aa64d21..df745d5e 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -361,8 +361,8 @@ def _is_directly_stringifiable(self, value, obj): } ) - def _deserialize(self, value, attr, obj, **kwargs): - pass + def _deserialize(self, value, attr, data, **kwargs): + return value def _serialize(self, value, attr, obj, **kwargs): if self._is_directly_stringifiable(value, obj): From faabf2a635f98ed26eb3b5a404cb8eeb5ac3acb8 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Wed, 15 May 2019 12:03:23 +0100 Subject: [PATCH 21/60] Only create page if both start and limit are not None --- faculty/clients/experiment.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index df745d5e..1c436d94 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -744,9 +744,11 @@ def query_runs( ListExperimentRunsResponse """ endpoint = "/project/{}/run/query".format(project_id) - # runs_query = QueryRuns(filter, sort, Page(start, limit)) + page = None + if start is not None and limit is not None: + page = Page(start, limit) payload = QueryRunsSchema().dump( - QueryRuns(filter, sort, Page(start, limit)) + QueryRuns(filter, sort, page) ) return self._post( endpoint, ListExperimentRunsResponseSchema(), json=payload From cbfd76163153079ce248d489e1cba7a09d041c33 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 
14:47:06 +0100 Subject: [PATCH 22/60] Throw a RunQueryFilterValidation error during serialisation if None is passed to conditions --- faculty/clients/experiment.py | 11 +++++++--- tests/clients/test_experiment.py | 35 ++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1c436d94..741396bd 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -380,7 +380,7 @@ class FilterField(fields.Field): Field that serialises/deserialises a run filter. """ - def _deserialize(self, value, attr, obj, **kwargs): + def _deserialize(self, value, attr, data, **kwargs): if value is None: return None elif isinstance(value, SingleFilter): @@ -389,8 +389,13 @@ def _deserialize(self, value, attr, obj, **kwargs): return CompoundFilterSchema().load(value) def _serialize(self, value, attr, obj, **kwargs): - if value is None: + print(type(obj)) + if value is None and isinstance(obj, QueryRuns): return None + elif value is None: + raise RunQueryFilterValidation( + "Validation error serialising a None filter", value + ) if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) else: @@ -406,7 +411,7 @@ class SingleFilterSchema(BaseSchema): class CompoundFilterSchema(BaseSchema): operator = EnumField(CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField()) + conditions = fields.List(FilterField(skip_if=None)) class SortSchema(BaseSchema): diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 569f2cec..0bd43920 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -53,6 +53,7 @@ QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, + RunQueryFilterValidation, SingleFilter, SingleFilterBy, SingleFilterOperator, @@ -353,7 +354,7 @@ def test_experiment_run_data_schema_multiple(): EXPERIMENT_ID_FILTER = 
SingleFilter( SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 ) -EXPERIMENT_ID_BODY = { +EXPERIMENT_ID_FILTER_BODY = { "by": "experimentId", "key": None, "operator": "eq", @@ -428,7 +429,6 @@ def test_experiment_run_data_schema_multiple(): SingleFilterOperator.EQUAL_TO, "tag_value", ), - None, ], ) AND_FILTER_BODY = { @@ -440,7 +440,6 @@ def test_experiment_run_data_schema_multiple(): "operator": "eq", "value": "tag_value", }, - None, ], } @@ -453,7 +452,6 @@ def test_experiment_run_data_schema_multiple(): SingleFilterOperator.EQUAL_TO, "tag_value", ), - None, ], ) OR_FILTER_BODY = { @@ -465,7 +463,6 @@ def test_experiment_run_data_schema_multiple(): "operator": "eq", "value": "tag_value", }, - None, ], } @@ -502,13 +499,14 @@ def test_experiment_run_data_schema_multiple(): [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], + [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], [RUN_ID_FILTER, RUN_ID_BODY], [DELETED_AT_FILTER, DELETED_AT_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY], + [OR_FILTER, OR_FILTER_BODY], ], ) @pytest.mark.parametrize( @@ -562,6 +560,31 @@ def test_single_filter_validation(mocker): ) +def test_compound_filter_validation(mocker): + with pytest.raises( + RunQueryFilterValidation, + match="Validation error serialising a None filter" + ): + filter = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None + ], + ) + + queryRunsObj = QueryRuns(filter, None, None) + QueryRunsSchema().dump(queryRunsObj) + + + + + def test_sort_validation(mocker): with pytest.raises( ValueError, From 752f2605576c267b02e13fcdeece2f989c86d9c1 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 15:07:42 +0100 Subject: [PATCH 23/60] Combine AND and OR filter in 
test_experiment --- tests/clients/test_experiment.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 0bd43920..00b92b75 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -451,7 +451,18 @@ def test_experiment_run_data_schema_multiple(): "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value", - ), + ), + CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + ], + ) ], ) OR_FILTER_BODY = { @@ -463,6 +474,17 @@ def test_experiment_run_data_schema_multiple(): "operator": "eq", "value": "tag_value", }, + { + "operator": "and", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", + }, + ], + } ], } From 09deb7ab1a052b28730ee3c9078c0c47c9ec5564 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 15:12:53 +0100 Subject: [PATCH 24/60] Rename RUN_ID_BODY AND DELETED_AT_BODY to RUN_ID_FILTER_BODY and DELETED_AT_FILTER_BODY --- tests/clients/test_experiment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 00b92b75..f1d09822 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -367,7 +367,7 @@ def test_experiment_run_data_schema_multiple(): SingleFilterOperator.EQUAL_TO, EXPERIMENT_RUN_ID, ) -RUN_ID_BODY = { +RUN_ID_FILTER_BODY = { "by": "runId", "key": None, "operator": "eq", @@ -377,7 +377,7 @@ def test_experiment_run_data_schema_multiple(): DELETED_AT_FILTER = SingleFilter( SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT ) -DELETED_AT_BODY = { +DELETED_AT_FILTER_BODY = { "by": "deletedAt", "key": None, "operator": "eq", @@ -522,8 +522,8 @@ def
test_experiment_run_data_schema_multiple(): [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], - [RUN_ID_FILTER, RUN_ID_BODY], - [DELETED_AT_FILTER, DELETED_AT_BODY], + [RUN_ID_FILTER, RUN_ID_FILTER_BODY], + [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], From dde11e7cdeec533ebd923d45d255aa5f05de210c Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 16:12:10 +0100 Subject: [PATCH 25/60] Reformat code using black --- faculty/clients/experiment.py | 4 +--- tests/clients/test_experiment.py | 36 +++++++++++++------------------- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 741396bd..3847e98d 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -752,9 +752,7 @@ def query_runs( page = None if start is not None and limit is not None: page = Page(start, limit) - payload = QueryRunsSchema().dump( - QueryRuns(filter, sort, page) - ) + payload = QueryRunsSchema().dump(QueryRuns(filter, sort, page)) return self._post( endpoint, ListExperimentRunsResponseSchema(), json=payload ) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index f1d09822..c39cb511 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -428,18 +428,13 @@ def test_experiment_run_data_schema_multiple(): "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value", - ), + ) ], ) AND_FILTER_BODY = { "operator": "and", "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, + {"by": "tag", "key": "tag_key", "operator": "eq", "value": "tag_value"} ], } @@ -451,18 +446,18 @@ def test_experiment_run_data_schema_multiple(): "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value", - ), + ), CompoundFilter( 
operator=CompoundFilterOperator.AND, conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ) ], - ) + ), ], ) OR_FILTER_BODY = { @@ -482,9 +477,9 @@ def test_experiment_run_data_schema_multiple(): "key": "tag_key", "operator": "eq", "value": "tag_value", - }, + } ], - } + }, ], } @@ -585,7 +580,7 @@ def test_single_filter_validation(mocker): def test_compound_filter_validation(mocker): with pytest.raises( RunQueryFilterValidation, - match="Validation error serialising a None filter" + match="Validation error serialising a None filter", ): filter = CompoundFilter( operator=CompoundFilterOperator.OR, @@ -596,7 +591,7 @@ def test_compound_filter_validation(mocker): SingleFilterOperator.EQUAL_TO, "tag_value", ), - None + None, ], ) @@ -604,9 +599,6 @@ def test_compound_filter_validation(mocker): QueryRunsSchema().dump(queryRunsObj) - - - def test_sort_validation(mocker): with pytest.raises( ValueError, From f076fe5139c66e42cff0ff3cdfccb3bc81036ce2 Mon Sep 17 00:00:00 2001 From: Elias Benussi <4412300+eliasbenussi@users.noreply.github.com> Date: Wed, 22 May 2019 12:44:09 +0200 Subject: [PATCH 26/60] Update faculty/clients/experiment.py Co-Authored-By: Andrew Crozier --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 3847e98d..a3b8c77c 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -733,7 +733,7 @@ def query_runs( Parameters ---------- project_id : uuid.UUID - filter: either SingleFilter object or CompoundFilter object, optional + filter: SingleFilter or CompoundFilter, optional To filter runs of experiments with the given filter. By default, runs from all experiments are returned. 
sort: List[Sort], optional From 11bbdfafbb7367e87d0819e31b5ca8cc26c7eb4a Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Wed, 22 May 2019 11:54:42 +0100 Subject: [PATCH 27/60] Reduce duplication --- tests/clients/test_experiment.py | 53 ++++---------------------------- 1 file changed, 6 insertions(+), 47 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c39cb511..c88deaee 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -421,65 +421,24 @@ def test_experiment_run_data_schema_multiple(): } AND_FILTER = CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ) - ], + operator=CompoundFilterOperator.AND, conditions=[TAG_FILTER] ) -AND_FILTER_BODY = { - "operator": "and", - "conditions": [ - {"by": "tag", "key": "tag_key", "operator": "eq", "value": "tag_value"} - ], -} +AND_FILTER_BODY = {"operator": "and", "conditions": [TAG_FILTER_BODY]} OR_FILTER = CompoundFilter( operator=CompoundFilterOperator.OR, conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), + TAG_FILTER, CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ) - ], + operator=CompoundFilterOperator.AND, conditions=[PARAM_FILTER] ), ], ) OR_FILTER_BODY = { "operator": "or", "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, - { - "operator": "and", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - } - ], - }, + TAG_FILTER_BODY, + {"operator": "and", "conditions": [PARAM_FILTER_BODY]}, ], } From 5a094740eb374d19f671ffb63afff47d55e5f9fb Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Wed, 22 May 2019 11:54:54 +0100 
Subject: [PATCH 28/60] Remove print statement --- faculty/clients/experiment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a3b8c77c..0a0228e9 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -389,7 +389,6 @@ def _deserialize(self, value, attr, data, **kwargs): return CompoundFilterSchema().load(value) def _serialize(self, value, attr, obj, **kwargs): - print(type(obj)) if value is None and isinstance(obj, QueryRuns): return None elif value is None: From a9eea182f8efaba2b1ff832cc209dc6721abe26d Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Thu, 23 May 2019 14:22:31 +0100 Subject: [PATCH 29/60] Replace RunQueryFilterValidation with marshmallow ValidationError, utilise marshmallow serialize methods for SingleFilterValueField serialisation --- faculty/clients/experiment.py | 89 +++++++++++++++++--------------- tests/clients/test_experiment.py | 64 +++++++++++++++++------ 2 files changed, 96 insertions(+), 57 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 0a0228e9..06331820 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -14,13 +14,18 @@ from collections import namedtuple from enum import Enum -import uuid -from marshmallow import fields, post_load, utils as marshmallow_utils +from marshmallow import fields, post_load, ValidationError from marshmallow_enum import EnumField from faculty.clients.base import BaseClient, BaseSchema, Conflict +error_messages = { + "invalid_param_value": "Invalid param value", + "invalid_filter_type": "Invalid filter type", + "invalid_none_filter": "A none filter", +} + class ExperimentNameConflict(Exception): def __init__(self, name): @@ -44,12 +49,6 @@ def __init__(self, message, experiment_id): self.experiment_id = experiment_id -class RunQueryFilterValidation(Exception): - def __init__(self, message, value): - super(RunQueryFilterValidation, 
self).__init__(message) - self.value = value - - class ExperimentRunStatus(Enum): RUNNING = "running" FINISHED = "finished" @@ -111,12 +110,24 @@ class ExperimentRunStatus(Enum): class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): - if by.needs_key() and key is None: + if isinstance(by, SingleFilterBy) and by.needs_key() and key is None: raise ValueError( "key must not be none for filter type {}".format(by) ) - elif not by.needs_key() and key is not None: + elif ( + isinstance(by, SingleFilterBy) + and not by.needs_key() + and key is not None + ): raise ValueError("key must be none for filter type {}".format(by)) + elif by == SingleFilterBy.PARAM and operator.is_numeric_operator(): + if not (isinstance(value, float) or isinstance(value, int)): + raise ValueError( + ( + "value can not be type {}. It has to be either an int " + + "or a float" + ).format(type(value)) + ) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -132,6 +143,14 @@ class SingleFilterOperator(Enum): GREATER_THAN = "gt" GREATER_THAN_OR_EQUAL_TO = "ge" + def is_numeric_operator(self): + return self in { + SingleFilterOperator.LESS_THAN, + SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, + SingleFilterOperator.GREATER_THAN, + SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, + } + class SingleFilterBy(Enum): PROJECT_ID = "projectId" @@ -339,40 +358,30 @@ class SingleFilterValueField(fields.Field): Field that serialises/deserialises a run filter. 
""" - def _is_valid_uuid(self, value, obj): - return isinstance(value, uuid.UUID) and ( - obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID} - ) - - def _is_valid_experiment_id(self, value, obj): - return ( - isinstance(value, int) and obj.by == SingleFilterBy.EXPERIMENT_ID - ) - - def _is_directly_stringifiable(self, value, obj): - return ( - self._is_valid_uuid(value, obj) - or self._is_valid_experiment_id(value, obj) - or obj.by - in { - SingleFilterBy.TAG, - SingleFilterBy.PARAM, - SingleFilterBy.METRIC, - } - ) - def _deserialize(self, value, attr, data, **kwargs): return value def _serialize(self, value, attr, obj, **kwargs): - if self._is_directly_stringifiable(value, obj): - return str(value) + if obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: + field = fields.UUID() + elif obj.by == SingleFilterBy.EXPERIMENT_ID: + field = fields.Integer() elif obj.by == SingleFilterBy.DELETED_AT: - return marshmallow_utils.from_iso_datetime(str(value)).isoformat() + field = fields.DateTime() + elif obj.by == SingleFilterBy.TAG: + field = fields.Str() + elif obj.by == SingleFilterBy.PARAM: + if isinstance(obj.value, int) or isinstance(obj.value, float): + field = fields.Float() + elif isinstance(obj.value, str): + field = fields.Str() + else: + raise ValidationError(error_messages["invalid_param_value"]) + elif obj.by == SingleFilterBy.METRIC: + field = fields.Float() else: - raise RunQueryFilterValidation( - "Validation error serialising run query filter", value - ) + raise ValidationError(error_messages["invalid_filter_type"]) + return field._serialize(value, attr, obj, **kwargs) class FilterField(fields.Field): @@ -392,9 +401,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None and isinstance(obj, QueryRuns): return None elif value is None: - raise RunQueryFilterValidation( - "Validation error serialising a None filter", value - ) + raise ValidationError(error_messages["invalid_none_filter"]) if isinstance(value, SingleFilter): 
return SingleFilterSchema().dump(value) else: diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c88deaee..acee5b57 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -53,10 +53,10 @@ QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, - RunQueryFilterValidation, SingleFilter, SingleFilterBy, SingleFilterOperator, + SingleFilterSchema, Sort, SortBy, SortOrder, @@ -358,7 +358,7 @@ def test_experiment_run_data_schema_multiple(): "by": "experimentId", "key": None, "operator": "eq", - "value": "1", + "value": 1, } RUN_ID_FILTER = SingleFilter( @@ -394,13 +394,26 @@ def test_experiment_run_data_schema_multiple(): "value": "tag_value", } -PARAM_FILTER = SingleFilter( +PARAM_NUM_FILTER = SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, + 1.0, +) +PARAM_NUM_FILTER_BODY = { + "by": "param", + "key": "param_key", + "operator": "ge", + "value": 1.0, +} + +PARAM_TEXT_FILTER = SingleFilter( SingleFilterBy.PARAM, "param_key", SingleFilterOperator.EQUAL_TO, "param_value", ) -PARAM_FILTER_BODY = { +PARAM_TEXT_FILTER_BODY = { "by": "param", "key": "param_key", "operator": "eq", @@ -408,16 +421,13 @@ def test_experiment_run_data_schema_multiple(): } METRIC_FILTER = SingleFilter( - SingleFilterBy.METRIC, - "metric_key", - SingleFilterOperator.EQUAL_TO, - "metric_value", + SingleFilterBy.METRIC, "metric_key", SingleFilterOperator.EQUAL_TO, 1.0 ) METRIC_FILTER_BODY = { "by": "metric", "key": "metric_key", "operator": "eq", - "value": "metric_value", + "value": 1.0, } AND_FILTER = CompoundFilter( @@ -430,7 +440,7 @@ def test_experiment_run_data_schema_multiple(): conditions=[ TAG_FILTER, CompoundFilter( - operator=CompoundFilterOperator.AND, conditions=[PARAM_FILTER] + operator=CompoundFilterOperator.AND, conditions=[PARAM_TEXT_FILTER] ), ], ) @@ -438,7 +448,7 @@ def test_experiment_run_data_schema_multiple(): "operator": "or", 
"conditions": [ TAG_FILTER_BODY, - {"operator": "and", "conditions": [PARAM_FILTER_BODY]}, + {"operator": "and", "conditions": [PARAM_TEXT_FILTER_BODY]}, ], } @@ -479,7 +489,8 @@ def test_experiment_run_data_schema_multiple(): [RUN_ID_FILTER, RUN_ID_FILTER_BODY], [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], [TAG_FILTER, TAG_FILTER_BODY], - [PARAM_FILTER, PARAM_FILTER_BODY], + [PARAM_NUM_FILTER, PARAM_NUM_FILTER_BODY], + [PARAM_TEXT_FILTER, PARAM_TEXT_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY], [OR_FILTER, OR_FILTER_BODY], @@ -509,6 +520,17 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): assert data == expected_json +def test_single_filter_value_field_validation(mocker): + with pytest.raises(ValidationError): + singleFilterObj = SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.GREATER_THAN, + True, + ) + SingleFilterSchema().dump(singleFilterObj) + + def test_single_filter_validation(mocker): with pytest.raises( ValueError, @@ -534,13 +556,23 @@ def test_single_filter_validation(mocker): SingleFilterOperator.EQUAL_TO, "tag_value", ) + with pytest.raises( + ValueError, + match=( + "value can not be type {}. 
It has to be either an int " + + "or a float" + ).format(type("param_value")), + ): + SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.GREATER_THAN, + "param_value", + ) def test_compound_filter_validation(mocker): - with pytest.raises( - RunQueryFilterValidation, - match="Validation error serialising a None filter", - ): + with pytest.raises(ValidationError): filter = CompoundFilter( operator=CompoundFilterOperator.OR, conditions=[ From 6b74cffae182965ec1905753eceee6b2a8a8b5ef Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Thu, 23 May 2019 14:28:35 +0100 Subject: [PATCH 30/60] Serialize single filter value field as number when SingleFilterBy is PARAM --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 06331820..1f34ace3 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -372,7 +372,7 @@ def _serialize(self, value, attr, obj, **kwargs): field = fields.Str() elif obj.by == SingleFilterBy.PARAM: if isinstance(obj.value, int) or isinstance(obj.value, float): - field = fields.Float() + field = fields.Number() elif isinstance(obj.value, str): field = fields.Str() else: From 9de8cab152249f2f7dc348cc251c952adade67ea Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Thu, 23 May 2019 17:10:26 +0100 Subject: [PATCH 31/60] Rename variable --- tests/clients/test_experiment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index acee5b57..28374c83 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -481,7 +481,7 @@ def test_experiment_run_data_schema_multiple(): @pytest.mark.parametrize( - "pfilter,pfilter_body", + "filter,filter_body", [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], @@ -509,10 +509,10 @@ def test_experiment_run_data_schema_multiple(): 
[MULTI_SORT, MULTI_SORT_BODY], ], ) -def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): - queryRunsObj = QueryRuns(pfilter, psort, PAGE) +def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): + queryRunsObj = QueryRuns(filter, psort, PAGE) expected_json = { - "filter": pfilter_body, + "filter": filter_body, "sort": psort_body, "page": PAGE_BODY, } From 1fb2461dca3bc5c62f86a8eb68281ee27da9c69d Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Thu, 23 May 2019 17:13:28 +0100 Subject: [PATCH 32/60] Rephrase docs for better clarity --- faculty/clients/experiment.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1f34ace3..1aaabe65 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -743,8 +743,9 @@ def query_runs( To filter runs of experiments with the given filter. By default, runs from all experiments are returned. sort: List[Sort], optional - Runs are ordered using sorting elements lexicographically. By - default, experiment runs are sorted by their startedAt value. + Runs are order using the conditions in sort. The relative + importance of each condition gradually decreases in order. + By default, experiment runs are sorted by their startedAt value. start : int, optional The (zero-indexed) starting point of runs to retrieve. 
limit : int, optional From 463ebadf4b8e6cd0b46212f3ed50526828c0c964 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 11:14:43 +0100 Subject: [PATCH 33/60] Implement list_runs in terms of query_runs --- faculty/clients/experiment.py | 46 ++++++++++++++------ tests/clients/test_experiment.py | 73 +++++++++++++++++++------------- 2 files changed, 76 insertions(+), 43 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1aaabe65..6b38c255 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -694,10 +694,11 @@ def list_runs( ------- ListExperimentRunsResponse """ - if lifecycle_stage is not None: - raise NotImplementedError("lifecycle_stage is not supported.") - query_params = [] + experiment_ids_filter = None + lifecycle_filter = None + filter = None + if experiment_ids is not None: if len(experiment_ids) == 0: return ListExperimentRunsResponse( @@ -706,18 +707,37 @@ def list_runs( start=0, size=0, previous=None, next=None ), ) - for experiment_id in experiment_ids: - query_params.append(("experimentId", experiment_id)) + experiment_id_filters = [ + SingleFilter( + SingleFilterBy.EXPERIMENT_ID, + None, + SingleFilterOperator.EQUAL_TO, + value, + ) + for value in experiment_ids + ] + experiment_ids_filter = CompoundFilter( + CompoundFilterOperator.OR, experiment_id_filters + ) + if lifecycle_stage is not None: + lifecycle_filter = SingleFilter( + SingleFilterBy.DELETED_AT, + None, + SingleFilterOperator.DEFINED, + lifecycle_stage == LifecycleStage.DELETED, + ) - if start is not None: - query_params.append(("start", start)) - if limit is not None: - query_params.append(("limit", limit)) + if experiment_ids_filter is not None and lifecycle_filter is not None: + filter = CompoundFilter( + CompoundFilterOperator.AND, + [experiment_ids_filter, lifecycle_filter], + ) + elif experiment_ids_filter is not None: + filter = experiment_ids_filter + elif lifecycle_filter is not None: + filter = 
lifecycle_filter - endpoint = "/project/{}/run".format(project_id) - return self._get( - endpoint, ListExperimentRunsResponseSchema(), params=query_params - ) + return self.query_runs(project_id, filter, None, start, limit) def query_runs( self, project_id, filter=None, sort=None, start=None, limit=None diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 28374c83..c8822a08 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -869,70 +869,83 @@ def test_restore_experiment_runs_response_schema_invalid(mocker): def test_experiment_client_list_runs_all(mocker): mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) + request_schema_mock = mocker.patch( + "faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump client = ExperimentClient(mocker.Mock()) list_result = client.list_runs(PROJECT_ID) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[], + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, ) def test_experiment_client_list_runs_experiments_filter(mocker): mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) + request_schema_mock = mocker.patch( + 
"faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump client = ExperimentClient(mocker.Mock()) list_result = client.list_runs(PROJECT_ID, experiment_ids=[123, 456]) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("experimentId", 123), ("experimentId", 456)], - ) - - -def test_experiment_client_list_runs_experiments_filter_empty(mocker): - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) - assert list_result == ListExperimentRunsResponse( - runs=[], - pagination=Pagination(start=0, size=0, previous=None, next=None), + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, ) def test_experiment_client_list_runs_page(mocker): mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) + request_schema_mock = mocker.patch( + "faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump client = ExperimentClient(mocker.Mock()) list_result = client.list_runs(PROJECT_ID, start=20, limit=10) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("start", 20), ("limit", 10)], + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + 
"/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, + ) + + +def test_experiment_client_list_runs_experiments_filter_empty(mocker): + client = ExperimentClient(mocker.Mock()) + list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) + + assert list_result == ListExperimentRunsResponse( + runs=[], + pagination=Pagination(start=0, size=0, previous=None, next=None), ) From 16c6cd5aa45e3dce047db69e3a8e7348f87528ab Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 11:20:35 +0100 Subject: [PATCH 34/60] Update docs --- faculty/clients/experiment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 6b38c255..6b1c3223 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -685,6 +685,9 @@ def list_runs( To filter runs of experiments with the given IDs only. If an empty list is passed, a result with an empty list of runs is returned. By default, runs from all experiments are returned. + lifecycle_stage: LifecycleStage, optional + To filter runs of experiments in a specific lifecycle stage only. + By default, runs in any stage are returned. start : int, optional The (zero-indexed) starting point of runs to retrieve. 
limit : int, optional From d00dbeeceb46fd57b2926f6e8ef7fda105097fc7 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 13:02:47 +0100 Subject: [PATCH 35/60] Update delete and restore runs to use filter objects --- faculty/clients/experiment.py | 44 +++++++++++++++++++------------- tests/clients/test_experiment.py | 44 +++++++++++++++++++++++--------- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 6b1c3223..08a4d472 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -891,15 +891,19 @@ def delete_runs(self, project_id, run_ids=None): deleted_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_id, + ) + for run_id in run_ids + ] + run_ids_filter = CompoundFilter( + CompoundFilterOperator.OR, run_id_filters + ) + payload = {"filter": run_ids_filter} return self._post( endpoint, DeleteExperimentRunsResponseSchema(), json=payload @@ -932,15 +936,19 @@ def restore_runs(self, project_id, run_ids=None): restored_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_id, + ) + for run_id in run_ids + ] + run_ids_filter = CompoundFilter( + CompoundFilterOperator.OR, run_id_filters + ) + payload = {"filter": run_ids_filter} return self._post( endpoint, RestoreExperimentRunsResponseSchema(), json=payload diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c8822a08..929df4f2 100644 --- 
a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -1216,13 +1216,23 @@ def test_delete_runs(mocker): ) expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, + "filter": CompoundFilter( + CompoundFilterOperator.OR, + [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[0], + ), + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[1], + ), ], - } + ) } ExperimentClient._post.assert_called_once_with( @@ -1276,13 +1286,23 @@ def test_restore_runs(mocker): ) expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, + "filter": CompoundFilter( + CompoundFilterOperator.OR, + [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[0], + ), + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[1], + ), ], - } + ) } ExperimentClient._post.assert_called_once_with( From 4110c91bd9e9438c312f275f8d97c71ef122444f Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 14:30:26 +0100 Subject: [PATCH 36/60] Unnest condition --- faculty/clients/experiment.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 08a4d472..905a4967 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -20,7 +20,7 @@ from faculty.clients.base import BaseClient, BaseSchema, Conflict -error_messages = { +_ERROR_MESSAGES = { "invalid_param_value": "Invalid param value", "invalid_filter_type": "Invalid filter type", "invalid_none_filter": "A none filter", @@ -120,14 +120,17 @@ def __new__(cls, 
by, key, operator, value): and key is not None ): raise ValueError("key must be none for filter type {}".format(by)) - elif by == SingleFilterBy.PARAM and operator.is_numeric_operator(): - if not (isinstance(value, float) or isinstance(value, int)): - raise ValueError( - ( - "value can not be type {}. It has to be either an int " - + "or a float" - ).format(type(value)) - ) + elif ( + by == SingleFilterBy.PARAM + and operator.is_numeric_operator() + and not (isinstance(value, float) or isinstance(value, int)) + ): + raise ValueError( + ( + "value can not be type {}. It has to be either an int " + + "or a float" + ).format(type(value)) + ) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -376,11 +379,11 @@ def _serialize(self, value, attr, obj, **kwargs): elif isinstance(obj.value, str): field = fields.Str() else: - raise ValidationError(error_messages["invalid_param_value"]) + raise ValidationError(_ERROR_MESSAGES["invalid_param_value"]) elif obj.by == SingleFilterBy.METRIC: field = fields.Float() else: - raise ValidationError(error_messages["invalid_filter_type"]) + raise ValidationError(_ERROR_MESSAGES["invalid_filter_type"]) return field._serialize(value, attr, obj, **kwargs) @@ -401,7 +404,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None and isinstance(obj, QueryRuns): return None elif value is None: - raise ValidationError(error_messages["invalid_none_filter"]) + raise ValidationError(_ERROR_MESSAGES["invalid_none_filter"]) if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) else: @@ -715,9 +718,9 @@ def list_runs( SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, - value, + experiment_id, ) - for value in experiment_ids + for experiment_id in experiment_ids ] experiment_ids_filter = CompoundFilter( CompoundFilterOperator.OR, experiment_id_filters From 7a83f8cbc1ffb3874da63536031abe82901df8a8 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 16:20:38 
+0100 Subject: [PATCH 37/60] Parametrise single filter construction tests --- faculty/clients/experiment.py | 2 +- tests/clients/test_experiment.py | 49 +++++++++++++++----------------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 905a4967..ad83779a 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -127,7 +127,7 @@ def __new__(cls, by, key, operator, value): ): raise ValueError( ( - "value can not be type {}. It has to be either an int " + "invalid type {}. Value has to be either an int " + "or a float" ).format(type(value)) ) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 929df4f2..e7a8e1e5 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -531,44 +531,41 @@ def test_single_filter_value_field_validation(mocker): SingleFilterSchema().dump(singleFilterObj) -def test_single_filter_validation(mocker): - with pytest.raises( - ValueError, - match="key must be none for filter type {}".format( - SingleFilterBy.PROJECT_ID - ), - ): - SingleFilter( +@pytest.mark.parametrize( + "by,key,op,value,err_msg", + [ + [ SingleFilterBy.PROJECT_ID, "invalid_key", SingleFilterOperator.EQUAL_TO, PROJECT_ID, - ) - with pytest.raises( - ValueError, - match="key must not be none for filter type {}".format( - SingleFilterBy.TAG - ), - ): - SingleFilter( + "key must be none for filter type {}".format( + SingleFilterBy.PROJECT_ID + ), + ], + [ SingleFilterBy.TAG, None, SingleFilterOperator.EQUAL_TO, "tag_value", - ) - with pytest.raises( - ValueError, - match=( - "value can not be type {}. It has to be either an int " - + "or a float" - ).format(type("param_value")), - ): - SingleFilter( + "key must not be none for filter type {}".format( + SingleFilterBy.TAG + ), + ], + [ SingleFilterBy.PARAM, "param_key", SingleFilterOperator.GREATER_THAN, "param_value", - ) + "invalid type {}. 
Value has to be either an int or a float".format( + type("param_value") + ), + ], + ], +) +def test_single_filter_validation(mocker, by, key, op, value, err_msg): + with pytest.raises(ValueError, match=err_msg): + SingleFilter(by, key, op, value) def test_compound_filter_validation(mocker): From 093d9775effcab68ca1a9630cba8869c6c466b71 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 3 Jun 2019 11:17:24 +0100 Subject: [PATCH 38/60] Serialise as Boolean if operator is defined --- faculty/clients/experiment.py | 4 +++- tests/clients/test_experiment.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index ad83779a..a1d6cae5 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -365,7 +365,9 @@ def _deserialize(self, value, attr, data, **kwargs): return value def _serialize(self, value, attr, obj, **kwargs): - if obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: + if obj.operator == SingleFilterOperator.DEFINED: + field = fields.Boolean() + elif obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: field = fields.UUID() elif obj.by == SingleFilterBy.EXPERIMENT_ID: field = fields.Integer() diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index e7a8e1e5..bd6b169e 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -384,6 +384,16 @@ def test_experiment_run_data_schema_multiple(): "value": DELETED_AT_STRING_PYTHON, } +DELETED_FILTER = SingleFilter( + SingleFilterBy.DELETED_AT, None, SingleFilterOperator.DEFINED, True +) +DELETED_FILTER_BODY = { + "by": "deletedAt", + "key": None, + "operator": "defined", + "value": True, +} + TAG_FILTER = SingleFilter( SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" ) @@ -488,6 +498,7 @@ def test_experiment_run_data_schema_multiple(): [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], [RUN_ID_FILTER, 
RUN_ID_FILTER_BODY], [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], + [DELETED_FILTER, DELETED_FILTER_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_NUM_FILTER, PARAM_NUM_FILTER_BODY], [PARAM_TEXT_FILTER, PARAM_TEXT_FILTER_BODY], From 499f1dc17d506dfaa6990bb670e79bb6cc0142a4 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 3 Jun 2019 14:48:01 +0100 Subject: [PATCH 39/60] Add parameterised tests for single filter schema and sort schema --- tests/clients/test_experiment.py | 237 +++++++++++++++++++------------ 1 file changed, 144 insertions(+), 93 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index bd6b169e..15ed62ef 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -60,6 +60,7 @@ Sort, SortBy, SortOrder, + SortSchema, Tag, TagSchema, ) @@ -97,6 +98,7 @@ "deletedAt": DELETED_AT_STRING, } +RUN_ID = uuid4() RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC) RUN_STARTED_AT_NO_TIMEZONE = datetime(2018, 3, 10, 11, 39, 12, 110000) RUN_STARTED_AT_STRING_PYTHON = "2018-03-10T11:39:12.110000+00:00" @@ -351,49 +353,6 @@ def test_experiment_run_data_schema_multiple(): "value": str(PROJECT_ID), } -EXPERIMENT_ID_FILTER = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 -) -EXPERIMENT_ID_FILTER_BODY = { - "by": "experimentId", - "key": None, - "operator": "eq", - "value": 1, -} - -RUN_ID_FILTER = SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - EXPERIMENT_RUN_ID, -) -RUN_ID_FILTER_BODY = { - "by": "runId", - "key": None, - "operator": "eq", - "value": str(EXPERIMENT_RUN_ID), -} - -DELETED_AT_FILTER = SingleFilter( - SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT -) -DELETED_AT_FILTER_BODY = { - "by": "deletedAt", - "key": None, - "operator": "eq", - "value": DELETED_AT_STRING_PYTHON, -} - -DELETED_FILTER = SingleFilter( - SingleFilterBy.DELETED_AT, None, 
SingleFilterOperator.DEFINED, True -) -DELETED_FILTER_BODY = { - "by": "deletedAt", - "key": None, - "operator": "defined", - "value": True, -} - TAG_FILTER = SingleFilter( SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" ) @@ -404,19 +363,6 @@ def test_experiment_run_data_schema_multiple(): "value": "tag_value", } -PARAM_NUM_FILTER = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, - 1.0, -) -PARAM_NUM_FILTER_BODY = { - "by": "param", - "key": "param_key", - "operator": "ge", - "value": 1.0, -} - PARAM_TEXT_FILTER = SingleFilter( SingleFilterBy.PARAM, "param_key", @@ -430,16 +376,6 @@ def test_experiment_run_data_schema_multiple(): "value": "param_value", } -METRIC_FILTER = SingleFilter( - SingleFilterBy.METRIC, "metric_key", SingleFilterOperator.EQUAL_TO, 1.0 -) -METRIC_FILTER_BODY = { - "by": "metric", - "key": "metric_key", - "operator": "eq", - "value": 1.0, -} - AND_FILTER = CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[TAG_FILTER] ) @@ -465,21 +401,9 @@ def test_experiment_run_data_schema_multiple(): RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] -STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] -STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] - DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] -PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] -PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] - -TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] -TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] - -METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] -METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] - MULTI_SORT = [ Sort(SortBy.PARAM, "param_key", 
SortOrder.ASC), Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), @@ -491,32 +415,19 @@ def test_experiment_run_data_schema_multiple(): @pytest.mark.parametrize( - "filter,filter_body", + "filter, filter_body", [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], - [RUN_ID_FILTER, RUN_ID_FILTER_BODY], - [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], - [DELETED_FILTER, DELETED_FILTER_BODY], - [TAG_FILTER, TAG_FILTER_BODY], - [PARAM_NUM_FILTER, PARAM_NUM_FILTER_BODY], - [PARAM_TEXT_FILTER, PARAM_TEXT_FILTER_BODY], - [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY], [OR_FILTER, OR_FILTER_BODY], ], ) @pytest.mark.parametrize( - "psort,psort_body", + "psort, psort_body", [ [None, None], - [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], - [STARTED_AT_SORT, STARTED_AT_SORT_BODY], [DURATION_SORT, DURATION_SORT_BODY], - [PARAM_SORT, PARAM_SORT_BODY], - [TAG_SORT, TAG_SORT_BODY], - [METRIC_SORT, METRIC_SORT_BODY], [MULTI_SORT, MULTI_SORT_BODY], ], ) @@ -531,6 +442,146 @@ def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): assert data == expected_json +@pytest.mark.parametrize( + "by,key,value,by_body,value_body", + [ + [ + SingleFilterBy.PROJECT_ID, + None, + PROJECT_ID, + "projectId", + str(PROJECT_ID), + ], + [ + SingleFilterBy.EXPERIMENT_ID, + None, + EXPERIMENT_ID, + "experimentId", + EXPERIMENT_ID, + ], + [SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)], + [ + SingleFilterBy.DELETED_AT, + None, + DELETED_AT, + "deletedAt", + DELETED_AT_STRING_PYTHON, + ], + [SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"], + [ + SingleFilterBy.PARAM, + "param-key", + "param-text-value", + "param", + "param-text-value", + ], + [SingleFilterBy.PARAM, "param-key", 1, "param", 1], + [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ], +) +@pytest.mark.parametrize( + "operator,operator_body", + [ + [SingleFilterOperator.EQUAL_TO, "eq"], + 
[SingleFilterOperator.NOT_EQUAL_TO, "ne"], + ], +) +def test_single_filter_schema_equality_operators( + mocker, by, key, value, by_body, value_body, operator, operator_body +): + singleFilterObj = SingleFilter(by, key, operator, value) + expected_json = { + "by": by_body, + "key": key, + "operator": operator_body, + "value": value_body, + } + data = SingleFilterSchema().dump(singleFilterObj) + assert data == expected_json + + +@pytest.mark.parametrize( + "by,key,value,by_body,value_body", + [ + [ + SingleFilterBy.DELETED_AT, + None, + DELETED_AT, + "deletedAt", + DELETED_AT_STRING_PYTHON, + ], + [SingleFilterBy.PARAM, "param-key", 1, "param", 1], + [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ], +) +@pytest.mark.parametrize( + "operator,operator_body", + [ + [SingleFilterOperator.LESS_THAN, "lt"], + [SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"], + [SingleFilterOperator.GREATER_THAN, "gt"], + [SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"], + ], +) +def test_single_filter_schema_relational_operators( + mocker, by, key, value, by_body, value_body, operator, operator_body +): + singleFilterObj = SingleFilter(by, key, operator, value) + expected_json = { + "by": by_body, + "key": key, + "operator": operator_body, + "value": value_body, + } + data = SingleFilterSchema().dump(singleFilterObj) + assert data == expected_json + + +@pytest.mark.parametrize( + "by,key,by_body", + [ + [SingleFilterBy.PROJECT_ID, None, "projectId"], + [SingleFilterBy.EXPERIMENT_ID, None, "experimentId"], + [SingleFilterBy.RUN_ID, None, "runId"], + [SingleFilterBy.DELETED_AT, None, "deletedAt"], + [SingleFilterBy.TAG, "tag-key", "tag"], + [SingleFilterBy.PARAM, "param-key", "param"], + [SingleFilterBy.METRIC, "metric-key", "metric"], + ], +) +def test_single_filter_schema_defined_operator(mocker, by, key, by_body): + singleFilterObj = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) + expected_json = { + "by": by_body, + "key": key, + "operator": "defined", + 
"value": True, + } + data = SingleFilterSchema().dump(singleFilterObj) + assert data == expected_json + + +@pytest.mark.parametrize( + "by,key,by_body", + [ + [SortBy.STARTED_AT, None, "startedAt"], + [SortBy.RUN_NUMBER, None, "runNumber"], + [SortBy.DURATION, None, "duration"], + [SortBy.TAG, "tag-key", "tag"], + [SortBy.PARAM, "param-key", "param"], + [SortBy.METRIC, "metric-key", "metric"], + ], +) +@pytest.mark.parametrize( + "order, order_body", [[SortOrder.ASC, "asc"], [SortOrder.DESC, "desc"]] +) +def test_sort_schema(mocker, by, key, by_body, order, order_body): + sortObj = Sort(by, key, order) + expected_json = {"by": by_body, "key": key, "order": order_body} + data = SortSchema().dump(sortObj) + assert data == expected_json + + def test_single_filter_value_field_validation(mocker): with pytest.raises(ValidationError): singleFilterObj = SingleFilter( From 2192deaa6146f5e3f104d144f3e351378c436872 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Fri, 7 Jun 2019 13:20:19 +0100 Subject: [PATCH 40/60] A bit of refactoring of the filter serialisation code * Use marshmallow's `field.fail()` mechanism * Map fitler types to field types with dict * Factor out param filter type polymorphism --- faculty/clients/experiment.py | 74 ++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a1d6cae5..71d20c32 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -20,12 +20,6 @@ from faculty.clients.base import BaseClient, BaseSchema, Conflict -_ERROR_MESSAGES = { - "invalid_param_value": "Invalid param value", - "invalid_filter_type": "Invalid filter type", - "invalid_none_filter": "A none filter", -} - class ExperimentNameConflict(Exception): def __init__(self, name): @@ -356,37 +350,61 @@ def make_restore_runs_response(self, data): return RestoreExperimentRunsResponse(**data) +class ParamValueField(fields.Field): + """Field that 
passes through strings or numbers.""" + + default_error_messages = { + "unsupported_type": "Param values must be of type str, int or float." + } + + def _determine_field(self, value): + if isinstance(value, str): + return fields.String() + elif isinstance(value, int) or isinstance(value, float): + return fields.Number() + else: + self.fail("unsupported_type") + + def _deserialize(self, value, attr, data, **kwargs): + field = self._determine_field(value) + return field._deserialize(value, attr, data, **kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + field = self._determine_field(value) + return field._serialize(value, attr, obj, **kwargs) + + class SingleFilterValueField(fields.Field): """ Field that serialises/deserialises a run filter. """ + default_error_messages = { + "invalid_filter_operator": "Invalid filter operator." + } + + FILTER_BY_FIELD_MAPPING = { + SingleFilterBy.PROJECT_ID: fields.UUID, + SingleFilterBy.RUN_ID: fields.UUID, + SingleFilterBy.EXPERIMENT_ID: fields.Integer, + SingleFilterBy.DELETED_AT: fields.DateTime, + SingleFilterBy.TAG: fields.String, + SingleFilterBy.PARAM: ParamValueField, + SingleFilterBy.METRIC: fields.Number, + } + def _deserialize(self, value, attr, data, **kwargs): return value def _serialize(self, value, attr, obj, **kwargs): if obj.operator == SingleFilterOperator.DEFINED: - field = fields.Boolean() - elif obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: - field = fields.UUID() - elif obj.by == SingleFilterBy.EXPERIMENT_ID: - field = fields.Integer() - elif obj.by == SingleFilterBy.DELETED_AT: - field = fields.DateTime() - elif obj.by == SingleFilterBy.TAG: - field = fields.Str() - elif obj.by == SingleFilterBy.PARAM: - if isinstance(obj.value, int) or isinstance(obj.value, float): - field = fields.Number() - elif isinstance(obj.value, str): - field = fields.Str() - else: - raise ValidationError(_ERROR_MESSAGES["invalid_param_value"]) - elif obj.by == SingleFilterBy.METRIC: - field = 
fields.Float() + field_cls = fields.Boolean else: - raise ValidationError(_ERROR_MESSAGES["invalid_filter_type"]) - return field._serialize(value, attr, obj, **kwargs) + try: + field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] + except KeyError: + self.fail("invalid_filter_operator") + return field_cls()._serialize(value, attr, obj, **kwargs) class FilterField(fields.Field): @@ -394,6 +412,8 @@ class FilterField(fields.Field): Field that serialises/deserialises a run filter. """ + default_error_messages = {"invalid_none_filter": "A none filter"} + def _deserialize(self, value, attr, data, **kwargs): if value is None: return None @@ -406,7 +426,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None and isinstance(obj, QueryRuns): return None elif value is None: - raise ValidationError(_ERROR_MESSAGES["invalid_none_filter"]) + self.fail("invalid_none_filter") if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) else: From b838919e4b02fe890f1a59ecc0587eaf0c725a87 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Fri, 7 Jun 2019 13:45:48 +0100 Subject: [PATCH 41/60] Use new OptionalField wrapper instead of coupling FilterField implementation to context in which it is used --- faculty/clients/experiment.py | 40 ++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 71d20c32..07c4839f 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -407,30 +407,50 @@ def _serialize(self, value, attr, obj, **kwargs): return field_cls()._serialize(value, attr, obj, **kwargs) +class OptionalField(fields.Field): + """Wrap another field, passing through Nones.""" + + def __init__(self, nested, *args, **kwargs): + self.nested = nested + super().__init__(*args, **kwargs) + + def _deserialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._deserialize(value, *args, 
**kwargs) + + def _serialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._serialize(value, *args, **kwargs) + + class FilterField(fields.Field): """ Field that serialises/deserialises a run filter. """ - default_error_messages = {"invalid_none_filter": "A none filter"} + default_error_messages = { + "invalid_filter_type": "Unsupported filter type." + } def _deserialize(self, value, attr, data, **kwargs): - if value is None: - return None - elif isinstance(value, SingleFilter): + # TODO: fix this - the isinstance check won't work as the filter hasn't + # yet been deserialised + if isinstance(value, SingleFilter): return SingleFilterSchema().load(value) else: return CompoundFilterSchema().load(value) def _serialize(self, value, attr, obj, **kwargs): - if value is None and isinstance(obj, QueryRuns): - return None - elif value is None: - self.fail("invalid_none_filter") if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) - else: + elif isinstance(value, CompoundFilter): return CompoundFilterSchema().dump(value) + else: + self.fail("invalid_filter_type") class SingleFilterSchema(BaseSchema): @@ -452,7 +472,7 @@ class SortSchema(BaseSchema): class QueryRunsSchema(BaseSchema): - filter = FilterField(required=True) + filter = OptionalField(FilterField()) sort = fields.List(fields.Nested(SortSchema)) page = fields.Nested(PageSchema, missing=None) From d547ba2df344a58404587ab7f3864642786e8cdf Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Fri, 7 Jun 2019 13:49:01 +0100 Subject: [PATCH 42/60] Remove unneeded _deserialize methods These methods were unused, untested, and in some cases broken. 
--- faculty/clients/experiment.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 07c4839f..a8f19f59 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -350,27 +350,20 @@ def make_restore_runs_response(self, data): return RestoreExperimentRunsResponse(**data) -class ParamValueField(fields.Field): +class ParamFilterValueField(fields.Field): """Field that passes through strings or numbers.""" default_error_messages = { "unsupported_type": "Param values must be of type str, int or float." } - def _determine_field(self, value): + def _serialize(self, value, attr, obj, **kwargs): if isinstance(value, str): - return fields.String() + field = fields.String() elif isinstance(value, int) or isinstance(value, float): - return fields.Number() + field = fields.Number() else: self.fail("unsupported_type") - - def _deserialize(self, value, attr, data, **kwargs): - field = self._determine_field(value) - return field._deserialize(value, attr, data, **kwargs) - - def _serialize(self, value, attr, obj, **kwargs): - field = self._determine_field(value) return field._serialize(value, attr, obj, **kwargs) @@ -389,13 +382,10 @@ class SingleFilterValueField(fields.Field): SingleFilterBy.EXPERIMENT_ID: fields.Integer, SingleFilterBy.DELETED_AT: fields.DateTime, SingleFilterBy.TAG: fields.String, - SingleFilterBy.PARAM: ParamValueField, + SingleFilterBy.PARAM: ParamFilterValueField, SingleFilterBy.METRIC: fields.Number, } - def _deserialize(self, value, attr, data, **kwargs): - return value - def _serialize(self, value, attr, obj, **kwargs): if obj.operator == SingleFilterOperator.DEFINED: field_cls = fields.Boolean @@ -436,14 +426,6 @@ class FilterField(fields.Field): "invalid_filter_type": "Unsupported filter type." 
} - def _deserialize(self, value, attr, data, **kwargs): - # TODO: fix this - the isinstance check won't work as the filter hasn't - # yet been deserialised - if isinstance(value, SingleFilter): - return SingleFilterSchema().load(value) - else: - return CompoundFilterSchema().load(value) - def _serialize(self, value, attr, obj, **kwargs): if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) From cdb81e2184ac074d0dde50f147a1b32b6d0089bb Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Fri, 7 Jun 2019 14:18:17 +0100 Subject: [PATCH 43/60] Remove parameter skip_if in CompoundFilterSchema --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a8f19f59..e9605fef 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -444,7 +444,7 @@ class SingleFilterSchema(BaseSchema): class CompoundFilterSchema(BaseSchema): operator = EnumField(CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField(skip_if=None)) + conditions = fields.List(FilterField()) class SortSchema(BaseSchema): From f2cd19a864c61c6b6b0345f039d10c3219932068 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Fri, 7 Jun 2019 14:40:00 +0100 Subject: [PATCH 44/60] Remove unused import --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index e9605fef..6acface8 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -15,7 +15,7 @@ from collections import namedtuple from enum import Enum -from marshmallow import fields, post_load, ValidationError +from marshmallow import fields, post_load from marshmallow_enum import EnumField from faculty.clients.base import BaseClient, BaseSchema, Conflict From d0d4dc41f885e3e7335ea04d9895a102be7973aa Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: 
Fri, 7 Jun 2019 14:26:09 +0100 Subject: [PATCH 45/60] Do some test refactoring --- tests/clients/test_experiment.py | 243 ++++++++++++++++--------------- 1 file changed, 127 insertions(+), 116 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 15ed62ef..0c6beb1c 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -417,25 +417,25 @@ def test_experiment_run_data_schema_multiple(): @pytest.mark.parametrize( "filter, filter_body", [ - [None, None], - [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [AND_FILTER, AND_FILTER_BODY], - [OR_FILTER, OR_FILTER_BODY], + (None, None), + (PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY), + (AND_FILTER, AND_FILTER_BODY), + (OR_FILTER, OR_FILTER_BODY), ], ) @pytest.mark.parametrize( - "psort, psort_body", + "sort, sort_body", [ - [None, None], - [DURATION_SORT, DURATION_SORT_BODY], - [MULTI_SORT, MULTI_SORT_BODY], + (None, None), + (DURATION_SORT, DURATION_SORT_BODY), + (MULTI_SORT, MULTI_SORT_BODY), ], ) -def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): - queryRunsObj = QueryRuns(filter, psort, PAGE) +def test_query_runs_schema(mocker, filter, sort, filter_body, sort_body): + queryRunsObj = QueryRuns(filter, sort, PAGE) expected_json = { "filter": filter_body, - "sort": psort_body, + "sort": sort_body, "page": PAGE_BODY, } data = QueryRunsSchema().dump(queryRunsObj) @@ -443,160 +443,154 @@ def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): @pytest.mark.parametrize( - "by,key,value,by_body,value_body", + "by, key, value, by_body, value_body", [ - [ + ( SingleFilterBy.PROJECT_ID, None, PROJECT_ID, "projectId", str(PROJECT_ID), - ], - [ + ), + ( SingleFilterBy.EXPERIMENT_ID, None, EXPERIMENT_ID, "experimentId", EXPERIMENT_ID, - ], - [SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)], - [ + ), + (SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)), + ( SingleFilterBy.DELETED_AT, None, 
DELETED_AT, "deletedAt", DELETED_AT_STRING_PYTHON, - ], - [SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"], - [ + ), + (SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"), + ( SingleFilterBy.PARAM, "param-key", "param-text-value", "param", "param-text-value", - ], - [SingleFilterBy.PARAM, "param-key", 1, "param", 1], - [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ), + (SingleFilterBy.PARAM, "param-key", 1, "param", 1), + (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), ], ) @pytest.mark.parametrize( - "operator,operator_body", + "operator, operator_body", [ - [SingleFilterOperator.EQUAL_TO, "eq"], - [SingleFilterOperator.NOT_EQUAL_TO, "ne"], + (SingleFilterOperator.EQUAL_TO, "eq"), + (SingleFilterOperator.NOT_EQUAL_TO, "ne"), ], ) def test_single_filter_schema_equality_operators( - mocker, by, key, value, by_body, value_body, operator, operator_body + by, key, value, by_body, value_body, operator, operator_body ): - singleFilterObj = SingleFilter(by, key, operator, value) + filter = SingleFilter(by, key, operator, value) expected_json = { "by": by_body, "key": key, "operator": operator_body, "value": value_body, } - data = SingleFilterSchema().dump(singleFilterObj) + data = SingleFilterSchema().dump(filter) assert data == expected_json @pytest.mark.parametrize( - "by,key,value,by_body,value_body", + "by, key, value, by_body, value_body", [ - [ + ( SingleFilterBy.DELETED_AT, None, DELETED_AT, "deletedAt", DELETED_AT_STRING_PYTHON, - ], - [SingleFilterBy.PARAM, "param-key", 1, "param", 1], - [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ), + (SingleFilterBy.PARAM, "param-key", 1, "param", 1), + (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), ], ) @pytest.mark.parametrize( - "operator,operator_body", + "operator, operator_body", [ - [SingleFilterOperator.LESS_THAN, "lt"], - [SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"], - [SingleFilterOperator.GREATER_THAN, "gt"], - 
[SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"], + (SingleFilterOperator.LESS_THAN, "lt"), + (SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"), + (SingleFilterOperator.GREATER_THAN, "gt"), + (SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"), ], ) def test_single_filter_schema_relational_operators( - mocker, by, key, value, by_body, value_body, operator, operator_body + by, key, value, by_body, value_body, operator, operator_body ): - singleFilterObj = SingleFilter(by, key, operator, value) + filter = SingleFilter(by, key, operator, value) expected_json = { "by": by_body, "key": key, "operator": operator_body, "value": value_body, } - data = SingleFilterSchema().dump(singleFilterObj) + data = SingleFilterSchema().dump(filter) assert data == expected_json @pytest.mark.parametrize( - "by,key,by_body", + "by, key, by_body", [ - [SingleFilterBy.PROJECT_ID, None, "projectId"], - [SingleFilterBy.EXPERIMENT_ID, None, "experimentId"], - [SingleFilterBy.RUN_ID, None, "runId"], - [SingleFilterBy.DELETED_AT, None, "deletedAt"], - [SingleFilterBy.TAG, "tag-key", "tag"], - [SingleFilterBy.PARAM, "param-key", "param"], - [SingleFilterBy.METRIC, "metric-key", "metric"], + (SingleFilterBy.PROJECT_ID, None, "projectId"), + (SingleFilterBy.EXPERIMENT_ID, None, "experimentId"), + (SingleFilterBy.RUN_ID, None, "runId"), + (SingleFilterBy.DELETED_AT, None, "deletedAt"), + (SingleFilterBy.TAG, "tag-key", "tag"), + (SingleFilterBy.PARAM, "param-key", "param"), + (SingleFilterBy.METRIC, "metric-key", "metric"), ], ) -def test_single_filter_schema_defined_operator(mocker, by, key, by_body): - singleFilterObj = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) +def test_single_filter_schema_defined_operator(by, key, by_body): + filter = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) expected_json = { "by": by_body, "key": key, "operator": "defined", "value": True, } - data = SingleFilterSchema().dump(singleFilterObj) + data = SingleFilterSchema().dump(filter) 
assert data == expected_json @pytest.mark.parametrize( - "by,key,by_body", + "by, value, message", [ - [SortBy.STARTED_AT, None, "startedAt"], - [SortBy.RUN_NUMBER, None, "runNumber"], - [SortBy.DURATION, None, "duration"], - [SortBy.TAG, "tag-key", "tag"], - [SortBy.PARAM, "param-key", "param"], - [SortBy.METRIC, "metric-key", "metric"], + (SingleFilterBy.PROJECT_ID, "invalid", "Not a valid UUID."), + (SingleFilterBy.EXPERIMENT_ID, "string", "Not a valid integer."), + (SingleFilterBy.RUN_ID, "invalid", "Not a valid UUID."), + ( + SingleFilterBy.DELETED_AT, + "invalid", + "cannot be formatted as a datetime", + ), + (SingleFilterBy.METRIC, "invalid", "Not a valid number."), + (SingleFilterBy.PARAM, None, "must be of type str, int or float"), ], ) -@pytest.mark.parametrize( - "order, order_body", [[SortOrder.ASC, "asc"], [SortOrder.DESC, "desc"]] -) -def test_sort_schema(mocker, by, key, by_body, order, order_body): - sortObj = Sort(by, key, order) - expected_json = {"by": by_body, "key": key, "order": order_body} - data = SortSchema().dump(sortObj) - assert data == expected_json - - -def test_single_filter_value_field_validation(mocker): - with pytest.raises(ValidationError): - singleFilterObj = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.GREATER_THAN, - True, - ) - SingleFilterSchema().dump(singleFilterObj) +def test_single_filter_invalid_value(by, value, message): + filter = SingleFilter( + by, + "key" if by.needs_key() else None, + SingleFilterOperator.EQUAL_TO, + value, + ) + with pytest.raises(ValidationError, match=message): + SingleFilterSchema().dump(filter) @pytest.mark.parametrize( - "by,key,op,value,err_msg", + "by, key, operator, value, message", [ - [ + ( SingleFilterBy.PROJECT_ID, "invalid_key", SingleFilterOperator.EQUAL_TO, @@ -604,8 +598,8 @@ def test_single_filter_value_field_validation(mocker): "key must be none for filter type {}".format( SingleFilterBy.PROJECT_ID ), - ], - [ + ), + ( SingleFilterBy.TAG, None, 
SingleFilterOperator.EQUAL_TO, @@ -613,8 +607,8 @@ def test_single_filter_value_field_validation(mocker): "key must not be none for filter type {}".format( SingleFilterBy.TAG ), - ], - [ + ), + ( SingleFilterBy.PARAM, "param_key", SingleFilterOperator.GREATER_THAN, @@ -622,43 +616,60 @@ def test_single_filter_value_field_validation(mocker): "invalid type {}. Value has to be either an int or a float".format( type("param_value") ), - ], + ), ], ) -def test_single_filter_validation(mocker, by, key, op, value, err_msg): - with pytest.raises(ValueError, match=err_msg): - SingleFilter(by, key, op, value) - - -def test_compound_filter_validation(mocker): +def test_single_filter_validation(by, key, operator, value, message): + with pytest.raises(ValueError, match=message): + SingleFilter(by, key, operator, value) + + +def test_compound_filter_validation(): + filter = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], + ) + run_query = QueryRuns(filter, None, None) with pytest.raises(ValidationError): - filter = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), - None, - ], - ) + QueryRunsSchema().dump(run_query) + - queryRunsObj = QueryRuns(filter, None, None) - QueryRunsSchema().dump(queryRunsObj) +@pytest.mark.parametrize( + "by, key, by_body", + [ + (SortBy.STARTED_AT, None, "startedAt"), + (SortBy.RUN_NUMBER, None, "runNumber"), + (SortBy.DURATION, None, "duration"), + (SortBy.TAG, "tag-key", "tag"), + (SortBy.PARAM, "param-key", "param"), + (SortBy.METRIC, "metric-key", "metric"), + ], +) +@pytest.mark.parametrize( + "order, order_body", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema(by, key, by_body, order, order_body): + sort = Sort(by, key, order) + expected_json = {"by": by_body, 
"key": key, "order": order_body} + data = SortSchema().dump(sort) + assert data == expected_json -def test_sort_validation(mocker): - with pytest.raises( - ValueError, - match="key must be none for sort type {}".format(SortBy.RUN_NUMBER), - ): +def test_sort_validate_no_key(): + with pytest.raises(ValueError, match="key must be none"): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) - with pytest.raises( - ValueError, - match="key must not be none for sort type {}".format(SortBy.TAG), - ): + + +def test_sort_validate_has_key(): + with pytest.raises(ValueError, match="key must not be none"): Sort(SortBy.TAG, None, SortOrder.ASC) From 0e7277e93d53fb36c80894d76c03dee65fe99422 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 12:06:01 +0100 Subject: [PATCH 46/60] Minor cleanups --- faculty/clients/experiment.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 80859b2f..1e177e40 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -117,13 +117,12 @@ def __new__(cls, by, key, operator, value): raise ValueError("key must be none for filter type {}".format(by)) elif ( by == SingleFilterBy.PARAM - and operator.is_numeric_operator() + and operator.is_numeric() and not (isinstance(value, float) or isinstance(value, int)) ): raise ValueError( ( - "invalid type {}. Value has to be either an int " - + "or a float" + "invalid type {}. 
Value has to be either an int or a float" ).format(type(value)) ) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -141,7 +140,7 @@ class SingleFilterOperator(Enum): GREATER_THAN = "gt" GREATER_THAN_OR_EQUAL_TO = "ge" - def is_numeric_operator(self): + def is_numeric(self): return self in { SingleFilterOperator.LESS_THAN, SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, From 2ab75fbe242ec5ca7004efe7d1190f2c8c72f415 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 13:04:52 +0100 Subject: [PATCH 47/60] WIP: Different filter implementations --- faculty/clients/experiment.py | 229 ++++++++++++++++++---------------- setup.py | 1 + 2 files changed, 125 insertions(+), 105 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1e177e40..84850361 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -15,8 +15,9 @@ from collections import namedtuple from enum import Enum -from marshmallow import fields, post_load +from marshmallow import fields, post_load, post_dump from marshmallow_enum import EnumField +from marshmallow_oneofschema import OneOfSchema from faculty.clients.base import BaseClient, BaseSchema, Conflict @@ -100,38 +101,14 @@ class ExperimentRunStatus(Enum): "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] ) -_SingleFilter = namedtuple("_SingleFilter", ["by", "key", "operator", "value"]) - -class SingleFilter(_SingleFilter): - def __new__(cls, by, key, operator, value): - if isinstance(by, SingleFilterBy) and by.needs_key() and key is None: - raise ValueError( - "key must not be none for filter type {}".format(by) - ) - elif ( - isinstance(by, SingleFilterBy) - and not by.needs_key() - and key is not None - ): - raise ValueError("key must be none for filter type {}".format(by)) - elif ( - by == SingleFilterBy.PARAM - and operator.is_numeric() - and not (isinstance(value, float) or isinstance(value, int)) - ): - raise ValueError( - 
( - "invalid type {}. Value has to be either an int or a float" - ).format(type(value)) - ) - return super(SingleFilter, cls).__new__(cls, by, key, operator, value) - - -CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) +class DiscreteOperator(Enum): + DEFINED = "defined" + EQUAL_TO = "eq" + NOT_EQUAL_TO = "ne" -class SingleFilterOperator(Enum): +class ContinuousOperator(Enum): DEFINED = "defined" EQUAL_TO = "eq" NOT_EQUAL_TO = "ne" @@ -140,35 +117,21 @@ class SingleFilterOperator(Enum): GREATER_THAN = "gt" GREATER_THAN_OR_EQUAL_TO = "ge" - def is_numeric(self): - return self in { - SingleFilterOperator.LESS_THAN, - SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, - SingleFilterOperator.GREATER_THAN, - SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, - } +class LogicalOperator(Enum): + AND = "and" + OR = "or" -class SingleFilterBy(Enum): - PROJECT_ID = "projectId" - EXPERIMENT_ID = "experimentId" - RUN_ID = "runId" - DELETED_AT = "deletedAt" - TAG = "tag" - PARAM = "param" - METRIC = "metric" - - def needs_key(self): - return self in { - SingleFilterBy.TAG, - SingleFilterBy.PARAM, - SingleFilterBy.METRIC, - } +ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) +ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) +RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) +DeletedAtFilter = namedtuple("DeletedAtFilter", ["operator", "value"]) +TagFilter = namedtuple("TagFilter", ["key", "operator", "value"]) +ParamFilter = namedtuple("ParamFilter", ["key", "operator", "value"]) +MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) -class CompoundFilterOperator(Enum): - AND = "and" - OR = "or" +CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) _Sort = namedtuple("_Sort", ["by", "key", "order"]) @@ -375,34 +338,34 @@ def _serialize(self, value, attr, obj, **kwargs): return field._serialize(value, attr, obj, **kwargs) -class SingleFilterValueField(fields.Field): 
- """ - Field that serialises/deserialises a run filter. - """ - - default_error_messages = { - "invalid_filter_operator": "Invalid filter operator." - } - - FILTER_BY_FIELD_MAPPING = { - SingleFilterBy.PROJECT_ID: fields.UUID, - SingleFilterBy.RUN_ID: fields.UUID, - SingleFilterBy.EXPERIMENT_ID: fields.Integer, - SingleFilterBy.DELETED_AT: fields.DateTime, - SingleFilterBy.TAG: fields.String, - SingleFilterBy.PARAM: ParamFilterValueField, - SingleFilterBy.METRIC: fields.Number, - } - - def _serialize(self, value, attr, obj, **kwargs): - if obj.operator == SingleFilterOperator.DEFINED: - field_cls = fields.Boolean - else: - try: - field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] - except KeyError: - self.fail("invalid_filter_operator") - return field_cls()._serialize(value, attr, obj, **kwargs) +# class SingleFilterValueField(fields.Field): +# """ +# Field that serialises/deserialises a run filter. +# """ + +# default_error_messages = { +# "invalid_filter_operator": "Invalid filter operator." +# } + +# FILTER_BY_FIELD_MAPPING = { +# SingleFilterBy.PROJECT_ID: fields.UUID, +# SingleFilterBy.RUN_ID: fields.UUID, +# SingleFilterBy.EXPERIMENT_ID: fields.Integer, +# SingleFilterBy.DELETED_AT: fields.DateTime, +# SingleFilterBy.TAG: fields.String, +# SingleFilterBy.PARAM: ParamFilterValueField, +# SingleFilterBy.METRIC: fields.Number, +# } + +# def _serialize(self, value, attr, obj, **kwargs): +# if obj.operator == SingleFilterOperator.DEFINED: +# field_cls = fields.Boolean +# else: +# try: +# field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] +# except KeyError: +# self.fail("invalid_filter_operator") +# return field_cls()._serialize(value, attr, obj, **kwargs) class OptionalField(fields.Field): @@ -425,34 +388,90 @@ def _serialize(self, value, *args, **kwargs): return self.nested._serialize(value, *args, **kwargs) -class FilterField(fields.Field): - """ - Field that serialises/deserialises a run filter. 
- """ +# class FilterField(fields.Field): +# """ +# Field that serialises/deserialises a run filter. +# """ - default_error_messages = { - "invalid_filter_type": "Unsupported filter type." - } +# default_error_messages = { +# "invalid_filter_type": "Unsupported filter type." +# } - def _serialize(self, value, attr, obj, **kwargs): - if isinstance(value, SingleFilter): - return SingleFilterSchema().dump(value) - elif isinstance(value, CompoundFilter): - return CompoundFilterSchema().dump(value) - else: - self.fail("invalid_filter_type") +# def _serialize(self, value, attr, obj, **kwargs): +# if isinstance(value, SingleFilter): +# return SingleFilterSchema().dump(value) +# elif isinstance(value, CompoundFilter): +# return CompoundFilterSchema().dump(value) +# else: +# self.fail("invalid_filter_type") -class SingleFilterSchema(BaseSchema): - by = EnumField(SingleFilterBy, by_value=True, required=True) - key = fields.String() - operator = EnumField(SingleFilterOperator, by_value=True, required=True) - value = SingleFilterValueField(required=True) +class ProjectIdFilterSchema(BaseSchema): + operator = EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.UUID(required=True) + by = fields.Constant("projectId", dump_only=True) + + +class ExperimentIdFilterSchema(BaseSchema): + operator = EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.Integer(required=True) + by = fields.Constant("experimentId", dump_only=True) + + +class RunIdFilterSchema(BaseSchema): + operator = EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.UUID(required=True) + by = fields.Constant("runId", dump_only=True) + + +class DeletedAtFilterSchema(BaseSchema): + operator = EnumField(ContinuousOperator, by_value=True, required=True) + value = fields.DateTime(required=True) + by = fields.Constant("deletedAt", dump_only=True) + + +class TagFilterSchema(BaseSchema): + key = fields.String(required=True) + operator = 
EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.String(required=True) + by = fields.Constant("tag", dump_only=True) + + +class ParamFilterSchema(BaseSchema): + key = fields.String(required=True) + operator = EnumField(ContinuousOperator, by_value=True, required=True) + value = ParamFilterValueField(required=True) + by = fields.Constant("param", dump_only=True) + + +class MetricFilterSchema(BaseSchema): + key = fields.String(required=True) + operator = EnumField(ContinuousOperator, by_value=True, required=True) + value = fields.Float(required=True) + by = fields.Constant("metric", dump_only=True) class CompoundFilterSchema(BaseSchema): - operator = EnumField(CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField()) + operator = EnumField(LogicalOperator, by_value=True, required=True) + conditions = fields.List(fields.Nested("FilterSchema")) + + +class FilterSchema(OneOfSchema): + type_schemas = { + "ProjectIdFilter": ProjectIdFilterSchema, + "ExperimentIdFilter": ExperimentIdFilterSchema, + "RunIdFilter": RunIdFilterSchema, + "DeletedAtFilter": DeletedAtFilterSchema, + "TagFilter": TagFilterSchema, + "ParamFilter": ParamFilterSchema, + "MetricFilter": MetricFilterSchema, + "CompoundFilter": CompoundFilterSchema, + } + + def dump(self, *args, **kwargs): + data = super(FilterSchema, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} class SortSchema(BaseSchema): @@ -462,7 +481,7 @@ class SortSchema(BaseSchema): class QueryRunsSchema(BaseSchema): - filter = OptionalField(FilterField()) + filter = OptionalField(fields.Nested(FilterSchema)) sort = fields.List(fields.Nested(SortSchema)) page = fields.Nested(PageSchema, missing=None) diff --git a/setup.py b/setup.py index a8772681..1e427358 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ # compatible version of python-dateutil is available 
"marshmallow[reco]>=3.0.0rc3", "marshmallow_enum", + "marshmallow-oneofschema", "boto3", "botocore", ], From 51b448cb03db127a20627dfcadc3bb6c13e1e71c Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 16:53:08 +0100 Subject: [PATCH 48/60] Complete filter and sort object refactor and tidy tests --- faculty/clients/experiment.py | 271 ++++++----- tests/clients/test_experiment.py | 752 ++++++++++++++----------------- 2 files changed, 472 insertions(+), 551 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 84850361..cd243803 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -15,7 +15,7 @@ from collections import namedtuple from enum import Enum -from marshmallow import fields, post_load, post_dump +from marshmallow import fields, post_load, pre_dump, ValidationError from marshmallow_enum import EnumField from marshmallow_oneofschema import OneOfSchema @@ -102,13 +102,7 @@ class ExperimentRunStatus(Enum): ) -class DiscreteOperator(Enum): - DEFINED = "defined" - EQUAL_TO = "eq" - NOT_EQUAL_TO = "ne" - - -class ContinuousOperator(Enum): +class ComparisonOperator(Enum): DEFINED = "defined" EQUAL_TO = "eq" NOT_EQUAL_TO = "ne" @@ -133,31 +127,12 @@ class LogicalOperator(Enum): CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) - -_Sort = namedtuple("_Sort", ["by", "key", "order"]) - - -class Sort(_Sort): - def __new__(cls, by, key, order): - if by.needs_key() and key is None: - raise ValueError( - "key must not be none for sort type {}".format(by) - ) - elif not by.needs_key() and key is not None: - raise ValueError("key must be none for sort type {}".format(by)) - return super(Sort, cls).__new__(cls, by, key, order) - - -class SortBy(Enum): - STARTED_AT = "startedAt" - RUN_NUMBER = "runNumber" - DURATION = "duration" - TAG = "tag" - PARAM = "param" - METRIC = "metric" - - def needs_key(self): - return self in {SortBy.TAG, SortBy.PARAM, SortBy.METRIC} 
+StartedAtSort = namedtuple("StartedAtSort", ["order"]) +RunNumberSort = namedtuple("RunNumberSort", ["order"]) +DurationSort = namedtuple("DurationSort", ["order"]) +TagSort = namedtuple("TagSort", ["key", "order"]) +ParamSort = namedtuple("ParamSort", ["key", "order"]) +MetricSort = namedtuple("MetricSort", ["key", "order"]) class SortOrder(Enum): @@ -165,7 +140,7 @@ class SortOrder(Enum): DESC = "desc" -QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) +RunQuery = namedtuple("RunQuery", ["filter", "sort", "page"]) class PageSchema(BaseSchema): @@ -338,36 +313,6 @@ def _serialize(self, value, attr, obj, **kwargs): return field._serialize(value, attr, obj, **kwargs) -# class SingleFilterValueField(fields.Field): -# """ -# Field that serialises/deserialises a run filter. -# """ - -# default_error_messages = { -# "invalid_filter_operator": "Invalid filter operator." -# } - -# FILTER_BY_FIELD_MAPPING = { -# SingleFilterBy.PROJECT_ID: fields.UUID, -# SingleFilterBy.RUN_ID: fields.UUID, -# SingleFilterBy.EXPERIMENT_ID: fields.Integer, -# SingleFilterBy.DELETED_AT: fields.DateTime, -# SingleFilterBy.TAG: fields.String, -# SingleFilterBy.PARAM: ParamFilterValueField, -# SingleFilterBy.METRIC: fields.Number, -# } - -# def _serialize(self, value, attr, obj, **kwargs): -# if obj.operator == SingleFilterOperator.DEFINED: -# field_cls = fields.Boolean -# else: -# try: -# field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] -# except KeyError: -# self.fail("invalid_filter_operator") -# return field_cls()._serialize(value, attr, obj, **kwargs) - - class OptionalField(fields.Field): """Wrap another field, passing through Nones.""" @@ -388,75 +333,112 @@ def _serialize(self, value, *args, **kwargs): return self.nested._serialize(value, *args, **kwargs) -# class FilterField(fields.Field): -# """ -# Field that serialises/deserialises a run filter. 
-# """ +class FilterValueField(fields.Field): + def __init__(self, other_field_type, *args, **kwargs): + self.other_field_type = other_field_type + super(FilterValueField, self).__init__(*args, **kwargs) -# default_error_messages = { -# "invalid_filter_type": "Unsupported filter type." -# } + def _serialize(self, value, attr, obj, **kwargs): + if obj.operator == ComparisonOperator.DEFINED: + field_cls = fields.Boolean + else: + field_cls = self.other_field_type + return field_cls()._serialize(value, attr, obj, **kwargs) -# def _serialize(self, value, attr, obj, **kwargs): -# if isinstance(value, SingleFilter): -# return SingleFilterSchema().dump(value) -# elif isinstance(value, CompoundFilter): -# return CompoundFilterSchema().dump(value) -# else: -# self.fail("invalid_filter_type") + +def _validate_discrete(operator): + if operator not in { + ComparisonOperator.DEFINED, + ComparisonOperator.EQUAL_TO, + ComparisonOperator.NOT_EQUAL_TO, + }: + raise ValidationError({"operator": "Not a discrete operator."}) class ProjectIdFilterSchema(BaseSchema): - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.UUID(required=True) + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) by = fields.Constant("projectId", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class ExperimentIdFilterSchema(BaseSchema): - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.Integer(required=True) + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Integer) by = fields.Constant("experimentId", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class RunIdFilterSchema(BaseSchema): - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.UUID(required=True) + operator = 
EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) by = fields.Constant("runId", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class DeletedAtFilterSchema(BaseSchema): - operator = EnumField(ContinuousOperator, by_value=True, required=True) - value = fields.DateTime(required=True) + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.DateTime) by = fields.Constant("deletedAt", dump_only=True) class TagFilterSchema(BaseSchema): - key = fields.String(required=True) - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.String(required=True) + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.String) by = fields.Constant("tag", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class ParamFilterSchema(BaseSchema): - key = fields.String(required=True) - operator = EnumField(ContinuousOperator, by_value=True, required=True) - value = ParamFilterValueField(required=True) + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(ParamFilterValueField) by = fields.Constant("param", dump_only=True) + @pre_dump + def check_operator(self, obj): + if isinstance(obj.value, str): + _validate_discrete(obj.operator) + return obj + class MetricFilterSchema(BaseSchema): - key = fields.String(required=True) - operator = EnumField(ContinuousOperator, by_value=True, required=True) - value = fields.Float(required=True) + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Float) by = fields.Constant("metric", dump_only=True) class CompoundFilterSchema(BaseSchema): - operator = EnumField(LogicalOperator, by_value=True, required=True) + operator = EnumField(LogicalOperator, by_value=True) 
conditions = fields.List(fields.Nested("FilterSchema")) -class FilterSchema(OneOfSchema): +class OneOfSchemaWithoutType(OneOfSchema): + def dump(self, *args, **kwargs): + data = super(OneOfSchemaWithoutType, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} + + +class FilterSchema(OneOfSchemaWithoutType): type_schemas = { "ProjectIdFilter": ProjectIdFilterSchema, "ExperimentIdFilter": ExperimentIdFilterSchema, @@ -468,19 +450,52 @@ class FilterSchema(OneOfSchema): "CompoundFilter": CompoundFilterSchema, } - def dump(self, *args, **kwargs): - data = super(FilterSchema, self).dump(*args, **kwargs) - # Remove the type field added by marshmallow-oneofschema - return {k: v for k, v in data.items() if k != "type"} + +class StartedAtSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("startedAt", dump_only=True) + + +class RunNumberSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("runNumber", dump_only=True) + + +class DurationSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("duration", dump_only=True) + + +class TagSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("tag", dump_only=True) + + +class ParamSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("param", dump_only=True) -class SortSchema(BaseSchema): - by = EnumField(SortBy, by_value=True, required=True) +class MetricSortSchema(BaseSchema): key = fields.String() - order = EnumField(SortOrder, by_value=True, required=True) + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("metric", dump_only=True) -class QueryRunsSchema(BaseSchema): +class SortSchema(OneOfSchemaWithoutType): + type_schemas = { + "StartedAtSort": StartedAtSortSchema, + 
"RunNumberSort": RunNumberSortSchema, + "DurationSort": DurationSortSchema, + "TagSort": TagSortSchema, + "ParamSort": ParamSortSchema, + "MetricSort": MetricSortSchema, + } + + +class RunQuerySchema(BaseSchema): filter = OptionalField(fields.Nested(FilterSchema)) sort = fields.List(fields.Nested(SortSchema)) page = fields.Nested(PageSchema, missing=None) @@ -785,29 +800,21 @@ def list_runs( ), ) experiment_id_filters = [ - SingleFilter( - SingleFilterBy.EXPERIMENT_ID, - None, - SingleFilterOperator.EQUAL_TO, - experiment_id, - ) + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, experiment_id) for experiment_id in experiment_ids ] experiment_ids_filter = CompoundFilter( - CompoundFilterOperator.OR, experiment_id_filters + LogicalOperator.OR, experiment_id_filters ) if lifecycle_stage is not None: - lifecycle_filter = SingleFilter( - SingleFilterBy.DELETED_AT, - None, - SingleFilterOperator.DEFINED, + lifecycle_filter = DeletedAtFilter( + ComparisonOperator.DEFINED, lifecycle_stage == LifecycleStage.DELETED, ) if experiment_ids_filter is not None and lifecycle_filter is not None: filter = CompoundFilter( - CompoundFilterOperator.AND, - [experiment_ids_filter, lifecycle_filter], + LogicalOperator.AND, [experiment_ids_filter, lifecycle_filter] ) elif experiment_ids_filter is not None: filter = experiment_ids_filter @@ -856,7 +863,7 @@ def query_runs( page = None if start is not None and limit is not None: page = Page(start, limit) - payload = QueryRunsSchema().dump(QueryRuns(filter, sort, page)) + payload = RunQuerySchema().dump(RunQuery(filter, sort, page)) return self._post( endpoint, ListExperimentRunsResponseSchema(), json=payload ) @@ -975,18 +982,11 @@ def delete_runs(self, project_id, run_ids=None): ) else: run_id_filters = [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_id, - ) + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) for run_id in run_ids ] - run_ids_filter = CompoundFilter( - CompoundFilterOperator.OR, 
run_id_filters - ) - payload = {"filter": run_ids_filter} + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, DeleteExperimentRunsResponseSchema(), json=payload @@ -1020,18 +1020,11 @@ def restore_runs(self, project_id, run_ids=None): ) else: run_id_filters = [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_id, - ) + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) for run_id in run_ids ] - run_ids_filter = CompoundFilter( - CompoundFilterOperator.OR, run_id_filters - ) - payload = {"filter": run_ids_filter} + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, RestoreExperimentRunsResponseSchema(), json=payload diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index d37578f4..d6db90b0 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -22,49 +22,58 @@ from faculty.clients.base import Conflict from faculty.clients.experiment import ( - CreateRunSchema, + ComparisonOperator, CompoundFilter, - CompoundFilterOperator, + CreateRunSchema, DeleteExperimentRunsResponse, DeleteExperimentRunsResponseSchema, + DeletedAtFilter, + DurationSort, Experiment, ExperimentClient, - ExperimentNameConflict, ExperimentDeleted, + ExperimentIdFilter, + ExperimentNameConflict, ExperimentRun, ExperimentRunDataSchema, ExperimentRunSchema, ExperimentRunStatus, ExperimentSchema, + FilterSchema, LifecycleStage, ListExperimentRunsResponse, ListExperimentRunsResponseSchema, + LogicalOperator, Metric, MetricDataPoint, - MetricSchema, + MetricFilter, MetricHistory, MetricHistorySchema, + MetricSchema, + MetricSort, Page, PageSchema, Pagination, PaginationSchema, Param, ParamConflict, + ParamFilter, ParamSchema, - QueryRuns, - QueryRunsSchema, + ParamSort, + ProjectIdFilter, RestoreExperimentRunsResponse, 
RestoreExperimentRunsResponseSchema, - SingleFilter, - SingleFilterBy, - SingleFilterOperator, - SingleFilterSchema, - Sort, - SortBy, + RunIdFilter, + RunNumberSort, + RunQuery, + RunQuerySchema, SortOrder, SortSchema, + StartedAtSort, Tag, + TagFilter, TagSchema, + TagSort, ) PROJECT_ID = uuid4() @@ -348,334 +357,315 @@ def test_experiment_run_data_schema_multiple(): assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} -PROJECT_ID_FILTER = SingleFilter( - SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID -) +PROJECT_ID_FILTER = ProjectIdFilter(ComparisonOperator.EQUAL_TO, PROJECT_ID) PROJECT_ID_FILTER_BODY = { "by": "projectId", - "key": None, "operator": "eq", "value": str(PROJECT_ID), } -TAG_FILTER = SingleFilter( - SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" -) +TAG_FILTER = TagFilter("tag-key", ComparisonOperator.EQUAL_TO, "tag-value") TAG_FILTER_BODY = { "by": "tag", - "key": "tag_key", + "key": "tag-key", "operator": "eq", - "value": "tag_value", + "value": "tag-value", } -PARAM_TEXT_FILTER = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.EQUAL_TO, - "param_value", -) -PARAM_TEXT_FILTER_BODY = { - "by": "param", - "key": "param_key", - "operator": "eq", - "value": "param_value", -} -AND_FILTER = CompoundFilter( - operator=CompoundFilterOperator.AND, conditions=[TAG_FILTER] -) -AND_FILTER_BODY = {"operator": "and", "conditions": [TAG_FILTER_BODY]} - -OR_FILTER = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - TAG_FILTER, - CompoundFilter( - operator=CompoundFilterOperator.AND, conditions=[PARAM_TEXT_FILTER] - ), - ], -) -OR_FILTER_BODY = { - "operator": "or", - "conditions": [ - TAG_FILTER_BODY, - {"operator": "and", "conditions": [PARAM_TEXT_FILTER_BODY]}, - ], -} +DEFINED_TEST_CASES = [ + (ComparisonOperator.DEFINED, False, "defined", False), + (ComparisonOperator.DEFINED, True, "defined", True), + (ComparisonOperator.DEFINED, 0, "defined", False), + 
(ComparisonOperator.DEFINED, 1, "defined", True), +] -RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] -RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] -DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] -DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] +def discrete_test_cases(value, expected): + return DEFINED_TEST_CASES + [ + (ComparisonOperator.EQUAL_TO, value, "eq", expected), + (ComparisonOperator.NOT_EQUAL_TO, value, "ne", expected), + ] -MULTI_SORT = [ - Sort(SortBy.PARAM, "param_key", SortOrder.ASC), - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), -] -MULTI_SORT_BODY = [ - {"by": "param", "key": "param_key", "order": "asc"}, - {"by": "runNumber", "key": None, "order": "desc"}, -] + +def continuous_test_cases(value, expected): + return discrete_test_cases(value, expected) + [ + (ComparisonOperator.GREATER_THAN, value, "gt", expected), + (ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value, "ge", expected), + (ComparisonOperator.LESS_THAN, value, "lt", expected), + (ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value, "le", expected), + ] @pytest.mark.parametrize( - "filter, filter_body", - [ - (None, None), - (PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY), - (AND_FILTER, AND_FILTER_BODY), - (OR_FILTER, OR_FILTER_BODY), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases(PROJECT_ID, str(PROJECT_ID)), ) +def test_filter_schema_project_id( + operator, value, expected_operator, expected_value +): + filter = ProjectIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "projectId", + "operator": expected_operator, + "value": expected_value, + } + + @pytest.mark.parametrize( - "sort, sort_body", - [ - (None, None), - (DURATION_SORT, DURATION_SORT_BODY), - (MULTI_SORT, MULTI_SORT_BODY), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases(EXPERIMENT_ID, EXPERIMENT_ID), ) -def 
test_query_runs_schema(mocker, filter, sort, filter_body, sort_body): - queryRunsObj = QueryRuns(filter, sort, PAGE) - expected_json = { - "filter": filter_body, - "sort": sort_body, - "page": PAGE_BODY, +def test_filter_schema_experiment_id( + operator, value, expected_operator, expected_value +): + filter = ExperimentIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "experimentId", + "operator": expected_operator, + "value": expected_value, } - data = QueryRunsSchema().dump(queryRunsObj) - assert data == expected_json @pytest.mark.parametrize( - "by, key, value, by_body, value_body", - [ - ( - SingleFilterBy.PROJECT_ID, - None, - PROJECT_ID, - "projectId", - str(PROJECT_ID), - ), - ( - SingleFilterBy.EXPERIMENT_ID, - None, - EXPERIMENT_ID, - "experimentId", - EXPERIMENT_ID, - ), - (SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)), - ( - SingleFilterBy.DELETED_AT, - None, - DELETED_AT, - "deletedAt", - DELETED_AT_STRING_PYTHON, - ), - (SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"), - ( - SingleFilterBy.PARAM, - "param-key", - "param-text-value", - "param", - "param-text-value", - ), - (SingleFilterBy.PARAM, "param-key", 1, "param", 1), - (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases(RUN_ID, str(RUN_ID)), ) +def test_filter_schema_run_id( + operator, value, expected_operator, expected_value +): + filter = RunIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "runId", + "operator": expected_operator, + "value": expected_value, + } + + @pytest.mark.parametrize( - "operator, operator_body", - [ - (SingleFilterOperator.EQUAL_TO, "eq"), - (SingleFilterOperator.NOT_EQUAL_TO, "ne"), - ], + "operator, value, expected_operator, expected_value", + continuous_test_cases(DELETED_AT, DELETED_AT_STRING_PYTHON), ) -def test_single_filter_schema_equality_operators( - by, key, 
value, by_body, value_body, operator, operator_body +def test_filter_schema_deleted_at( + operator, value, expected_operator, expected_value ): - filter = SingleFilter(by, key, operator, value) - expected_json = { - "by": by_body, - "key": key, - "operator": operator_body, - "value": value_body, + filter = DeletedAtFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "deletedAt", + "operator": expected_operator, + "value": expected_value, } - data = SingleFilterSchema().dump(filter) - assert data == expected_json @pytest.mark.parametrize( - "by, key, value, by_body, value_body", - [ - ( - SingleFilterBy.DELETED_AT, - None, - DELETED_AT, - "deletedAt", - DELETED_AT_STRING_PYTHON, - ), - (SingleFilterBy.PARAM, "param-key", 1, "param", 1), - (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases("tag-value", "tag-value"), ) +def test_filter_schema_tag(operator, value, expected_operator, expected_value): + filter = TagFilter("tag-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "tag", + "key": "tag-key", + "operator": expected_operator, + "value": expected_value, + } + + @pytest.mark.parametrize( - "operator, operator_body", - [ - (SingleFilterOperator.LESS_THAN, "lt"), - (SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"), - (SingleFilterOperator.GREATER_THAN, "gt"), - (SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases("param-value", "param-value") + + continuous_test_cases(123.2, 123.2), ) -def test_single_filter_schema_relational_operators( - by, key, value, by_body, value_body, operator, operator_body +def test_filter_schema_param( + operator, value, expected_operator, expected_value ): - filter = SingleFilter(by, key, operator, value) - expected_json = { - "by": by_body, - "key": key, - "operator": operator_body, - 
"value": value_body, + filter = ParamFilter("param-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "param", + "key": "param-key", + "operator": expected_operator, + "value": expected_value, } - data = SingleFilterSchema().dump(filter) - assert data == expected_json @pytest.mark.parametrize( - "by, key, by_body", - [ - (SingleFilterBy.PROJECT_ID, None, "projectId"), - (SingleFilterBy.EXPERIMENT_ID, None, "experimentId"), - (SingleFilterBy.RUN_ID, None, "runId"), - (SingleFilterBy.DELETED_AT, None, "deletedAt"), - (SingleFilterBy.TAG, "tag-key", "tag"), - (SingleFilterBy.PARAM, "param-key", "param"), - (SingleFilterBy.METRIC, "metric-key", "metric"), - ], + "operator, value, expected_operator, expected_value", + continuous_test_cases(45.6, 45.6), ) -def test_single_filter_schema_defined_operator(by, key, by_body): - filter = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) - expected_json = { - "by": by_body, - "key": key, - "operator": "defined", - "value": True, +def test_filter_schema_metric( + operator, value, expected_operator, expected_value +): + filter = MetricFilter("metric-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "metric", + "key": "metric-key", + "operator": expected_operator, + "value": expected_value, } - data = SingleFilterSchema().dump(filter) - assert data == expected_json @pytest.mark.parametrize( - "by, value, message", + "filter_type", + [ProjectIdFilter, ExperimentIdFilter, RunIdFilter, DeletedAtFilter], +) +def test_filter_schema_invalid_value_no_key(filter_type): + filter = filter_type(ComparisonOperator.EQUAL_TO, "invalid") + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", [(ParamFilter, None), (MetricFilter, "invalid")] +) +def test_filter_schema_invalid_value_with_key(filter_type, value): + filter = filter_type("key", ComparisonOperator.EQUAL_TO, value) + with 
pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", [ - (SingleFilterBy.PROJECT_ID, "invalid", "Not a valid UUID."), - (SingleFilterBy.EXPERIMENT_ID, "string", "Not a valid integer."), - (SingleFilterBy.RUN_ID, "invalid", "Not a valid UUID."), - ( - SingleFilterBy.DELETED_AT, - "invalid", - "cannot be formatted as a datetime", - ), - (SingleFilterBy.METRIC, "invalid", "Not a valid number."), - (SingleFilterBy.PARAM, None, "must be of type str, int or float"), + (ProjectIdFilter, PROJECT_ID), + (ExperimentIdFilter, EXPERIMENT_ID), + (RunIdFilter, RUN_ID), ], ) -def test_single_filter_invalid_value(by, value, message): - filter = SingleFilter( - by, - "key" if by.needs_key() else None, - SingleFilterOperator.EQUAL_TO, - value, - ) - with pytest.raises(ValidationError, match=message): - SingleFilterSchema().dump(filter) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_no_key(filter_type, value, operator): + filter = filter_type(operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) @pytest.mark.parametrize( - "by, key, operator, value, message", + "filter_type, value", + [(TagFilter, "tag-value"), (ParamFilter, "param-string-value")], +) +@pytest.mark.parametrize( + "operator", [ - ( - SingleFilterBy.PROJECT_ID, - "invalid_key", - SingleFilterOperator.EQUAL_TO, - PROJECT_ID, - "key must be none for filter type {}".format( - SingleFilterBy.PROJECT_ID - ), - ), - ( - SingleFilterBy.TAG, - None, - SingleFilterOperator.EQUAL_TO, - "tag_value", - "key must not be none for filter type {}".format( - SingleFilterBy.TAG - ), - ), - ( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.GREATER_THAN, - "param_value", - "invalid type {}. 
Value has to be either an int or a float".format( - type("param_value") - ), - ), + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, ], ) -def test_single_filter_validation(by, key, operator, value, message): - with pytest.raises(ValueError, match=message): - SingleFilter(by, key, operator, value) +def test_filter_schema_invalid_operator_with_key(filter_type, value, operator): + filter = filter_type("key", operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) -def test_compound_filter_validation(): +@pytest.mark.parametrize( + "operator, expected_operator", + [(LogicalOperator.AND, "and"), (LogicalOperator.OR, "or")], +) +def test_filter_schema_compound(operator, expected_operator): + filter = CompoundFilter(operator, [PROJECT_ID_FILTER, TAG_FILTER]) + data = FilterSchema().dump(filter) + assert data == { + "operator": expected_operator, + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + } + + +def test_filter_schema_nested(): filter = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.AND, [PROJECT_ID_FILTER, TAG_FILTER] + ), + CompoundFilter( + LogicalOperator.OR, [TAG_FILTER, PROJECT_ID_FILTER] ), - None, ], ) - run_query = QueryRuns(filter, None, None) - with pytest.raises(ValidationError): - QueryRunsSchema().dump(run_query) + data = FilterSchema().dump(filter) + assert data == { + "operator": "and", + "conditions": [ + { + "operator": "and", + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + }, + { + "operator": "or", + "conditions": [TAG_FILTER_BODY, PROJECT_ID_FILTER_BODY], + }, + ], + } @pytest.mark.parametrize( - "by, key, by_body", + "sort_type, by", [ - (SortBy.STARTED_AT, None, 
"startedAt"), - (SortBy.RUN_NUMBER, None, "runNumber"), - (SortBy.DURATION, None, "duration"), - (SortBy.TAG, "tag-key", "tag"), - (SortBy.PARAM, "param-key", "param"), - (SortBy.METRIC, "metric-key", "metric"), + (StartedAtSort, "startedAt"), + (RunNumberSort, "runNumber"), + (DurationSort, "duration"), ], ) @pytest.mark.parametrize( - "order, order_body", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] ) -def test_sort_schema(by, key, by_body, order, order_body): - sort = Sort(by, key, order) - expected_json = {"by": by_body, "key": key, "order": order_body} +def test_sort_schema_no_tag(sort_type, by, order, expected_order): + sort = sort_type(order) data = SortSchema().dump(sort) - assert data == expected_json + assert data == {"by": by, "order": expected_order} -def test_sort_validate_no_key(): - with pytest.raises(ValueError, match="key must be none"): - Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) +@pytest.mark.parametrize( + "sort_type, by", + [(TagSort, "tag"), (ParamSort, "param"), (MetricSort, "metric")], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_with_tag(sort_type, by, order, expected_order): + sort = sort_type("sort-key", order) + data = SortSchema().dump(sort) + assert data == {"by": by, "key": "sort-key", "order": expected_order} + +def test_run_query_schema(mocker): + mocker.patch.object(FilterSchema, "dump") + mocker.patch.object(SortSchema, "dump") + mocker.patch.object(PageSchema, "dump") -def test_sort_validate_has_key(): - with pytest.raises(ValueError, match="key must not be none"): - Sort(SortBy.TAG, None, SortOrder.ASC) + filter = mocker.Mock() + sorts = [mocker.Mock(), mocker.Mock()] + page = mocker.Mock() + + run_query = RunQuery(filter, sorts, page) + data = RunQuerySchema().dump(run_query) + + assert data == { + "filter": FilterSchema.dump.return_value, + 
"sort": [SortSchema.dump.return_value, SortSchema.dump.return_value], + "page": PageSchema.dump.return_value, + } + + +def test_run_query_schema_defaults(): + run_query = RunQuery(None, None, None) + data = RunQuerySchema().dump(run_query) + assert data == {"filter": None, "sort": None, "page": None} @pytest.mark.parametrize("description", [None, "experiment description"]) @@ -900,11 +890,16 @@ def test_list_runs_schema(mocker): assert data == LIST_EXPERIMENT_RUNS_RESPONSE -def test_page_schema(): +def test_page_schema_load(): data = PageSchema().load(PAGE_BODY) assert data == PAGE +def test_page_schema_dump(): + data = PageSchema().dump(PAGE) + assert data == PAGE_BODY + + def test_pagination_schema(): data = PaginationSchema().load(PAGINATION_BODY) assert data == PAGINATION @@ -942,85 +937,46 @@ def test_restore_experiment_runs_response_schema_invalid(mocker): RestoreExperimentRunsResponseSchema().load({}) -def test_experiment_client_list_runs_all(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - response_schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) - request_schema_mock = mocker.patch( - "faculty.clients.experiment.QueryRunsSchema" - ) - dump_mock = request_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - response_schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/query".format(PROJECT_ID), - response_schema_mock.return_value, - json=dump_mock.return_value, - ) - - -def test_experiment_client_list_runs_experiments_filter(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - response_schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) - request_schema_mock = mocker.patch( - 
"faculty.clients.experiment.QueryRunsSchema" - ) - dump_mock = request_schema_mock.return_value.dump +def test_experiment_client_list_runs(mocker): + mocker.patch.object(ExperimentClient, "query_runs") client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[123, 456]) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - response_schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/query".format(PROJECT_ID), - response_schema_mock.return_value, - json=dump_mock.return_value, - ) - - -def test_experiment_client_list_runs_page(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - response_schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" + response = client.list_runs( + PROJECT_ID, + experiment_ids=[123, 456], + lifecycle_stage=LifecycleStage.DELETED, + start=20, + limit=10, + ) + + assert response == ExperimentClient.query_runs.return_value + expected_filter = CompoundFilter( + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.OR, + [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 123), + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 456), + ], + ), + DeletedAtFilter(ComparisonOperator.DEFINED, True), + ], ) - request_schema_mock = mocker.patch( - "faculty.clients.experiment.QueryRunsSchema" + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, expected_filter, None, 20, 10 ) - dump_mock = request_schema_mock.return_value.dump - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, start=20, limit=10) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - response_schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/query".format(PROJECT_ID), - response_schema_mock.return_value, - json=dump_mock.return_value, - ) +def 
test_experiment_client_list_runs_defaults(mocker): + mocker.patch.object(ExperimentClient, "query_runs") -def test_experiment_client_list_runs_experiments_filter_empty(mocker): client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) + response = client.list_runs(PROJECT_ID) - assert list_result == ListExperimentRunsResponse( - runs=[], - pagination=Pagination(start=0, size=0, previous=None, next=None), + assert response == ExperimentClient.query_runs.return_value + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, None, None, None, None ) @@ -1031,32 +987,26 @@ def test_experiment_client_query_runs(mocker): response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) - request_schema_mock = mocker.patch( - "faculty.clients.experiment.QueryRunsSchema" - ) - dump_mock = request_schema_mock.return_value.dump + request_dump_mock = mocker.patch.object(RunQuerySchema, "dump") - test_filter = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2" - ) - test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] + filter = mocker.Mock() + sort = mocker.Mock() client = ExperimentClient(mocker.Mock()) list_result = client.query_runs( - PROJECT_ID, filter=test_filter, sort=test_sort, start=20, limit=10 + PROJECT_ID, filter, sort, start=20, limit=10 ) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - request_schema_mock.assert_called_once_with() - dump_mock.assert_called_once_with( - QueryRuns(test_filter, test_sort, Page(20, 10)) + request_dump_mock.assert_called_once_with( + RunQuery(filter, sort, Page(20, 10)) ) response_schema_mock.assert_called_once_with() ExperimentClient._post.assert_called_once_with( "/project/{}/run/query".format(PROJECT_ID), response_schema_mock.return_value, - json=dump_mock.return_value, + json=request_dump_mock.return_value, ) @@ -1279,41 +1229,28 @@ def test_delete_runs(mocker): mocker.patch.object( 
ExperimentClient, "_post", return_value=DELETE_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" ) + filter_dump_mock = mocker.patch.object(FilterSchema, "dump") run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) - assert ( - client.delete_runs(PROJECT_ID, run_ids) - == DELETE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": CompoundFilter( - CompoundFilterOperator.OR, - [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[0], - ), - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[1], - ), - ], - ) - } - + response = client.delete_runs(PROJECT_ID, run_ids) + + assert response == DELETE_EXPERIMENT_RUNS_RESPONSE + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) ExperimentClient._post.assert_called_once_with( "/project/{}/run/delete/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, ) @@ -1334,13 +1271,15 @@ def test_delete_runs_no_run_ids(mocker): def test_delete_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + client = ExperimentClient(mocker.Mock()) + response = client.delete_runs(PROJECT_ID, run_ids=[]) - assert client.delete_runs( - PROJECT_ID, run_ids=[] - ) == DeleteExperimentRunsResponse( + assert response == DeleteExperimentRunsResponse( deleted_run_ids=[], conflicted_run_ids=[] ) + ExperimentClient._post.assert_not_called() def test_restore_runs(mocker): @@ -1349,41 +1288,28 @@ def test_restore_runs(mocker): "_post", return_value=RESTORE_EXPERIMENT_RUNS_RESPONSE, ) - schema_mock = mocker.patch( + 
response_schema_mock = mocker.patch( "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" ) + filter_dump_mock = mocker.patch.object(FilterSchema, "dump") run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) - assert ( - client.restore_runs(PROJECT_ID, run_ids) - == RESTORE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": CompoundFilter( - CompoundFilterOperator.OR, - [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[0], - ), - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[1], - ), - ], - ) - } - + response = client.restore_runs(PROJECT_ID, run_ids) + + assert response == RESTORE_EXPERIMENT_RUNS_RESPONSE + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) ExperimentClient._post.assert_called_once_with( "/project/{}/run/restore/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, ) @@ -1404,10 +1330,12 @@ def test_restore_runs_no_run_ids(mocker): def test_restore_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + client = ExperimentClient(mocker.Mock()) + response = client.restore_runs(PROJECT_ID, run_ids=[]) - assert client.restore_runs( - PROJECT_ID, run_ids=[] - ) == RestoreExperimentRunsResponse( + assert response == RestoreExperimentRunsResponse( restored_run_ids=[], conflicted_run_ids=[] ) + ExperimentClient._post.assert_not_called() From f38e3b56ef799697de1c949ffbe48f2528cd08c8 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 17:15:33 +0100 Subject: [PATCH 49/60] Make experiment client module a package --- faculty/clients/{experiment.py => experiment/__init__.py} | 0 1 file changed, 0 
insertions(+), 0 deletions(-) rename faculty/clients/{experiment.py => experiment/__init__.py} (100%) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment/__init__.py similarity index 100% rename from faculty/clients/experiment.py rename to faculty/clients/experiment/__init__.py From e57711d5eda1b531ff9cc729f7df285c4636adcd Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 17:22:01 +0100 Subject: [PATCH 50/60] Mirror experiment client package structure in tests --- tests/clients/experiment/__init__.py | 0 tests/clients/{test_experiment.py => experiment/test_init.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/clients/experiment/__init__.py rename tests/clients/{test_experiment.py => experiment/test_init.py} (100%) diff --git a/tests/clients/experiment/__init__.py b/tests/clients/experiment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/clients/test_experiment.py b/tests/clients/experiment/test_init.py similarity index 100% rename from tests/clients/test_experiment.py rename to tests/clients/experiment/test_init.py From e198ddbdd975137639445c9e109f96a07d359996 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 13:16:24 +0100 Subject: [PATCH 51/60] Move models and schema to their own modules --- faculty/clients/experiment/__init__.py | 524 +------------ faculty/clients/experiment/models.py | 127 +++ faculty/clients/experiment/schemas.py | 412 ++++++++++ tests/clients/experiment/test_init.py | 958 +++-------------------- tests/clients/experiment/test_schemas.py | 728 +++++++++++++++++ 5 files changed, 1409 insertions(+), 1340 deletions(-) create mode 100644 faculty/clients/experiment/models.py create mode 100644 faculty/clients/experiment/schemas.py create mode 100644 tests/clients/experiment/test_schemas.py diff --git a/faculty/clients/experiment/__init__.py b/faculty/clients/experiment/__init__.py index cd243803..798a0154 100644 --- 
a/faculty/clients/experiment/__init__.py +++ b/faculty/clients/experiment/__init__.py @@ -12,14 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple -from enum import Enum - -from marshmallow import fields, post_load, pre_dump, ValidationError -from marshmallow_enum import EnumField -from marshmallow_oneofschema import OneOfSchema - -from faculty.clients.base import BaseClient, BaseSchema, Conflict +from faculty.clients.base import BaseClient, Conflict + +from faculty.clients.experiment.models import ( + ComparisonOperator, + CompoundFilter, + DeleteExperimentRunsResponse, + DeletedAtFilter, + ExperimentIdFilter, + LifecycleStage, + ListExperimentRunsResponse, + LogicalOperator, + Metric, + Page, + Pagination, + RestoreExperimentRunsResponse, + RunIdFilter, + RunQuery, +) +from faculty.clients.experiment.schemas import ( + CreateRunSchema, + DeleteExperimentRunsResponseSchema, + ExperimentRunDataSchema, + ExperimentRunInfoSchema, + ExperimentRunSchema, + ExperimentSchema, + FilterSchema, + ListExperimentRunsResponseSchema, + MetricHistorySchema, + RestoreExperimentRunsResponseSchema, + RunQuerySchema, +) class ExperimentNameConflict(Exception): @@ -44,491 +67,6 @@ def __init__(self, message, experiment_id): self.experiment_id = experiment_id -class ExperimentRunStatus(Enum): - RUNNING = "running" - FINISHED = "finished" - FAILED = "failed" - SCHEDULED = "scheduled" - KILLED = "killed" - - -Experiment = namedtuple( - "Experiment", - [ - "id", - "name", - "description", - "artifact_location", - "created_at", - "last_updated_at", - "deleted_at", - ], -) - - -ExperimentRun = namedtuple( - "ExperimentRun", - [ - "id", - "run_number", - "experiment_id", - "name", - "parent_run_id", - "artifact_location", - "status", - "started_at", - "ended_at", - "deleted_at", - "tags", - "params", - "metrics", - ], -) - -Metric = namedtuple("Metric", ["key", "value", "timestamp", "step"]) 
-Param = namedtuple("Param", ["key", "value"]) -Tag = namedtuple("Tag", ["key", "value"]) - -Page = namedtuple("Page", ["start", "limit"]) -Pagination = namedtuple("Pagination", ["start", "size", "previous", "next"]) -ListExperimentRunsResponse = namedtuple( - "ListExperimentRunsResponse", ["runs", "pagination"] -) -DeleteExperimentRunsResponse = namedtuple( - "DeleteExperimentRunsResponse", ["deleted_run_ids", "conflicted_run_ids"] -) -RestoreExperimentRunsResponse = namedtuple( - "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] -) - - -class ComparisonOperator(Enum): - DEFINED = "defined" - EQUAL_TO = "eq" - NOT_EQUAL_TO = "ne" - LESS_THAN = "lt" - LESS_THAN_OR_EQUAL_TO = "le" - GREATER_THAN = "gt" - GREATER_THAN_OR_EQUAL_TO = "ge" - - -class LogicalOperator(Enum): - AND = "and" - OR = "or" - - -ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) -ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) -RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) -DeletedAtFilter = namedtuple("DeletedAtFilter", ["operator", "value"]) -TagFilter = namedtuple("TagFilter", ["key", "operator", "value"]) -ParamFilter = namedtuple("ParamFilter", ["key", "operator", "value"]) -MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) - -CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) - -StartedAtSort = namedtuple("StartedAtSort", ["order"]) -RunNumberSort = namedtuple("RunNumberSort", ["order"]) -DurationSort = namedtuple("DurationSort", ["order"]) -TagSort = namedtuple("TagSort", ["key", "order"]) -ParamSort = namedtuple("ParamSort", ["key", "order"]) -MetricSort = namedtuple("MetricSort", ["key", "order"]) - - -class SortOrder(Enum): - ASC = "asc" - DESC = "desc" - - -RunQuery = namedtuple("RunQuery", ["filter", "sort", "page"]) - - -class PageSchema(BaseSchema): - start = fields.Integer(required=True) - limit = fields.Integer(required=True) - - @post_load - 
def make_page(self, data): - return Page(**data) - - -MetricDataPoint = namedtuple("Metric", ["value", "timestamp", "step"]) - -MetricHistory = namedtuple( - "MetricHistory", ["original_size", "subsampled", "key", "history"] -) - - -class MetricSchema(BaseSchema): - key = fields.String(required=True) - value = fields.Float(required=True) - timestamp = fields.DateTime(required=True) - step = fields.Integer(required=True) - - @post_load - def make_metric(self, data): - return Metric(**data) - - -class ParamSchema(BaseSchema): - key = fields.String(required=True) - value = fields.String(required=True) - - @post_load - def make_param(self, data): - return Param(**data) - - -class TagSchema(BaseSchema): - key = fields.String(required=True) - value = fields.String(required=True) - - @post_load - def make_tag(self, data): - return Tag(**data) - - -class LifecycleStage(Enum): - ACTIVE = "active" - DELETED = "deleted" - - -class ExperimentSchema(BaseSchema): - id = fields.Integer(data_key="experimentId", required=True) - name = fields.String(required=True) - description = fields.String(required=True) - artifact_location = fields.String( - data_key="artifactLocation", required=True - ) - created_at = fields.DateTime(data_key="createdAt", required=True) - last_updated_at = fields.DateTime(data_key="lastUpdatedAt", required=True) - deleted_at = fields.DateTime(data_key="deletedAt", missing=None) - - @post_load - def make_experiment(self, data): - return Experiment(**data) - - -class ExperimentRunSchema(BaseSchema): - id = fields.UUID(data_key="runId", required=True) - run_number = fields.Integer(data_key="runNumber", required=True) - experiment_id = fields.Integer(data_key="experimentId", required=True) - name = fields.String(required=True) - parent_run_id = fields.UUID(data_key="parentRunId", missing=None) - artifact_location = fields.String( - data_key="artifactLocation", required=True - ) - status = EnumField(ExperimentRunStatus, by_value=True, required=True) - started_at = 
fields.DateTime(data_key="startedAt", required=True) - ended_at = fields.DateTime(data_key="endedAt", missing=None) - deleted_at = fields.DateTime(data_key="deletedAt", missing=None) - tags = fields.Nested(TagSchema, many=True, required=True) - params = fields.Nested(ParamSchema, many=True, required=True) - metrics = fields.Nested(MetricSchema, many=True, required=True) - - @post_load - def make_experiment_run(self, data): - return ExperimentRun(**data) - - -class ExperimentRunDataSchema(BaseSchema): - metrics = fields.List(fields.Nested(MetricSchema)) - params = fields.List(fields.Nested(ParamSchema)) - tags = fields.List(fields.Nested(TagSchema)) - - -class ExperimentRunInfoSchema(BaseSchema): - status = EnumField(ExperimentRunStatus, by_value=True, required=True) - ended_at = fields.DateTime(data_key="endedAt", missing=None) - - -class PaginationSchema(BaseSchema): - start = fields.Integer(required=True) - size = fields.Integer(required=True) - previous = fields.Nested(PageSchema, missing=None) - next = fields.Nested(PageSchema, missing=None) - - @post_load - def make_pagination(self, data): - return Pagination(**data) - - -class ListExperimentRunsResponseSchema(BaseSchema): - pagination = fields.Nested(PaginationSchema, required=True) - runs = fields.Nested(ExperimentRunSchema, many=True, required=True) - - @post_load - def make_list_runs_response_schema(self, data): - return ListExperimentRunsResponse(**data) - - -class CreateRunSchema(BaseSchema): - name = fields.String() - parent_run_id = fields.UUID(data_key="parentRunId") - started_at = fields.DateTime(data_key="startedAt") - artifact_location = fields.String(data_key="artifactLocation") - tags = fields.Nested(TagSchema, many=True, required=True) - - -class DeleteExperimentRunsResponseSchema(BaseSchema): - deleted_run_ids = fields.List( - fields.UUID(), data_key="deletedRunIds", required=True - ) - conflicted_run_ids = fields.List( - fields.UUID(), data_key="conflictedRunIds", required=True - ) - - 
@post_load - def make_delete_runs_response(self, data): - return DeleteExperimentRunsResponse(**data) - - -class RestoreExperimentRunsResponseSchema(BaseSchema): - restored_run_ids = fields.List( - fields.UUID(), data_key="restoredRunIds", required=True - ) - conflicted_run_ids = fields.List( - fields.UUID(), data_key="conflictedRunIds", required=True - ) - - @post_load - def make_restore_runs_response(self, data): - return RestoreExperimentRunsResponse(**data) - - -class ParamFilterValueField(fields.Field): - """Field that passes through strings or numbers.""" - - default_error_messages = { - "unsupported_type": "Param values must be of type str, int or float." - } - - def _serialize(self, value, attr, obj, **kwargs): - if isinstance(value, str): - field = fields.String() - elif isinstance(value, int) or isinstance(value, float): - field = fields.Number() - else: - self.fail("unsupported_type") - return field._serialize(value, attr, obj, **kwargs) - - -class OptionalField(fields.Field): - """Wrap another field, passing through Nones.""" - - def __init__(self, nested, *args, **kwargs): - self.nested = nested - super().__init__(*args, **kwargs) - - def _deserialize(self, value, *args, **kwargs): - if value is None: - return None - else: - return self.nested._deserialize(value, *args, **kwargs) - - def _serialize(self, value, *args, **kwargs): - if value is None: - return None - else: - return self.nested._serialize(value, *args, **kwargs) - - -class FilterValueField(fields.Field): - def __init__(self, other_field_type, *args, **kwargs): - self.other_field_type = other_field_type - super(FilterValueField, self).__init__(*args, **kwargs) - - def _serialize(self, value, attr, obj, **kwargs): - if obj.operator == ComparisonOperator.DEFINED: - field_cls = fields.Boolean - else: - field_cls = self.other_field_type - return field_cls()._serialize(value, attr, obj, **kwargs) - - -def _validate_discrete(operator): - if operator not in { - ComparisonOperator.DEFINED, - 
ComparisonOperator.EQUAL_TO, - ComparisonOperator.NOT_EQUAL_TO, - }: - raise ValidationError({"operator": "Not a discrete operator."}) - - -class ProjectIdFilterSchema(BaseSchema): - operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.UUID) - by = fields.Constant("projectId", dump_only=True) - - @pre_dump - def check_operator(self, obj): - _validate_discrete(obj.operator) - return obj - - -class ExperimentIdFilterSchema(BaseSchema): - operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.Integer) - by = fields.Constant("experimentId", dump_only=True) - - @pre_dump - def check_operator(self, obj): - _validate_discrete(obj.operator) - return obj - - -class RunIdFilterSchema(BaseSchema): - operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.UUID) - by = fields.Constant("runId", dump_only=True) - - @pre_dump - def check_operator(self, obj): - _validate_discrete(obj.operator) - return obj - - -class DeletedAtFilterSchema(BaseSchema): - operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.DateTime) - by = fields.Constant("deletedAt", dump_only=True) - - -class TagFilterSchema(BaseSchema): - key = fields.String() - operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.String) - by = fields.Constant("tag", dump_only=True) - - @pre_dump - def check_operator(self, obj): - _validate_discrete(obj.operator) - return obj - - -class ParamFilterSchema(BaseSchema): - key = fields.String() - operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(ParamFilterValueField) - by = fields.Constant("param", dump_only=True) - - @pre_dump - def check_operator(self, obj): - if isinstance(obj.value, str): - _validate_discrete(obj.operator) - return obj - - -class MetricFilterSchema(BaseSchema): - key = fields.String() - operator = EnumField(ComparisonOperator, by_value=True) - value = 
FilterValueField(fields.Float) - by = fields.Constant("metric", dump_only=True) - - -class CompoundFilterSchema(BaseSchema): - operator = EnumField(LogicalOperator, by_value=True) - conditions = fields.List(fields.Nested("FilterSchema")) - - -class OneOfSchemaWithoutType(OneOfSchema): - def dump(self, *args, **kwargs): - data = super(OneOfSchemaWithoutType, self).dump(*args, **kwargs) - # Remove the type field added by marshmallow-oneofschema - return {k: v for k, v in data.items() if k != "type"} - - -class FilterSchema(OneOfSchemaWithoutType): - type_schemas = { - "ProjectIdFilter": ProjectIdFilterSchema, - "ExperimentIdFilter": ExperimentIdFilterSchema, - "RunIdFilter": RunIdFilterSchema, - "DeletedAtFilter": DeletedAtFilterSchema, - "TagFilter": TagFilterSchema, - "ParamFilter": ParamFilterSchema, - "MetricFilter": MetricFilterSchema, - "CompoundFilter": CompoundFilterSchema, - } - - -class StartedAtSortSchema(BaseSchema): - order = EnumField(SortOrder, by_value=True) - by = fields.Constant("startedAt", dump_only=True) - - -class RunNumberSortSchema(BaseSchema): - order = EnumField(SortOrder, by_value=True) - by = fields.Constant("runNumber", dump_only=True) - - -class DurationSortSchema(BaseSchema): - order = EnumField(SortOrder, by_value=True) - by = fields.Constant("duration", dump_only=True) - - -class TagSortSchema(BaseSchema): - key = fields.String() - order = EnumField(SortOrder, by_value=True) - by = fields.Constant("tag", dump_only=True) - - -class ParamSortSchema(BaseSchema): - key = fields.String() - order = EnumField(SortOrder, by_value=True) - by = fields.Constant("param", dump_only=True) - - -class MetricSortSchema(BaseSchema): - key = fields.String() - order = EnumField(SortOrder, by_value=True) - by = fields.Constant("metric", dump_only=True) - - -class SortSchema(OneOfSchemaWithoutType): - type_schemas = { - "StartedAtSort": StartedAtSortSchema, - "RunNumberSort": RunNumberSortSchema, - "DurationSort": DurationSortSchema, - "TagSort": 
TagSortSchema, - "ParamSort": ParamSortSchema, - "MetricSort": MetricSortSchema, - } - - -class RunQuerySchema(BaseSchema): - filter = OptionalField(fields.Nested(FilterSchema)) - sort = fields.List(fields.Nested(SortSchema)) - page = fields.Nested(PageSchema, missing=None) - - -class MetricDataPointSchema(BaseSchema): - """Deserialise a data point from the metric history endpoint. - - This schema is written with the expectation that it is not used alongside - the metric subsampling feature, which can result in null timestamp or step, - or a non-integer step. - """ - - value = fields.Float(required=True) - timestamp = fields.DateTime(required=True) - step = fields.Integer(required=True) - - @post_load - def make_metric(self, data): - return MetricDataPoint(**data) - - -class MetricHistorySchema(BaseSchema): - original_size = fields.Integer(data_key="originalSize", required=True) - subsampled = fields.Boolean(required=True) - key = fields.String(required=True) - history = fields.Nested(MetricDataPointSchema, many=True, required=True) - - @post_load - def make_history(self, data): - return MetricHistory(**data) - - class ExperimentClient(BaseClient): SERVICE_NAME = "atlas" diff --git a/faculty/clients/experiment/models.py b/faculty/clients/experiment/models.py new file mode 100644 index 00000000..00772458 --- /dev/null +++ b/faculty/clients/experiment/models.py @@ -0,0 +1,127 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import namedtuple +from enum import Enum + + +class ExperimentRunStatus(Enum): + RUNNING = "running" + FINISHED = "finished" + FAILED = "failed" + SCHEDULED = "scheduled" + KILLED = "killed" + + +Experiment = namedtuple( + "Experiment", + [ + "id", + "name", + "description", + "artifact_location", + "created_at", + "last_updated_at", + "deleted_at", + ], +) + + +ExperimentRun = namedtuple( + "ExperimentRun", + [ + "id", + "run_number", + "experiment_id", + "name", + "parent_run_id", + "artifact_location", + "status", + "started_at", + "ended_at", + "deleted_at", + "tags", + "params", + "metrics", + ], +) + + +class LifecycleStage(Enum): + ACTIVE = "active" + DELETED = "deleted" + + +Metric = namedtuple("Metric", ["key", "value", "timestamp", "step"]) +Param = namedtuple("Param", ["key", "value"]) +Tag = namedtuple("Tag", ["key", "value"]) + +Page = namedtuple("Page", ["start", "limit"]) +Pagination = namedtuple("Pagination", ["start", "size", "previous", "next"]) +ListExperimentRunsResponse = namedtuple( + "ListExperimentRunsResponse", ["runs", "pagination"] +) +DeleteExperimentRunsResponse = namedtuple( + "DeleteExperimentRunsResponse", ["deleted_run_ids", "conflicted_run_ids"] +) +RestoreExperimentRunsResponse = namedtuple( + "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] +) + + +class ComparisonOperator(Enum): + DEFINED = "defined" + EQUAL_TO = "eq" + NOT_EQUAL_TO = "ne" + LESS_THAN = "lt" + LESS_THAN_OR_EQUAL_TO = "le" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL_TO = "ge" + + +class LogicalOperator(Enum): + AND = "and" + OR = "or" + + +ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) +ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) +RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) +DeletedAtFilter = namedtuple("DeletedAtFilter", ["operator", 
"value"]) +TagFilter = namedtuple("TagFilter", ["key", "operator", "value"]) +ParamFilter = namedtuple("ParamFilter", ["key", "operator", "value"]) +MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) + +CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) + +StartedAtSort = namedtuple("StartedAtSort", ["order"]) +RunNumberSort = namedtuple("RunNumberSort", ["order"]) +DurationSort = namedtuple("DurationSort", ["order"]) +TagSort = namedtuple("TagSort", ["key", "order"]) +ParamSort = namedtuple("ParamSort", ["key", "order"]) +MetricSort = namedtuple("MetricSort", ["key", "order"]) + + +class SortOrder(Enum): + ASC = "asc" + DESC = "desc" + + +RunQuery = namedtuple("RunQuery", ["filter", "sort", "page"]) +MetricDataPoint = namedtuple("Metric", ["value", "timestamp", "step"]) + +MetricHistory = namedtuple( + "MetricHistory", ["original_size", "subsampled", "key", "history"] +) diff --git a/faculty/clients/experiment/schemas.py b/faculty/clients/experiment/schemas.py new file mode 100644 index 00000000..7f0ab1ea --- /dev/null +++ b/faculty/clients/experiment/schemas.py @@ -0,0 +1,412 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from marshmallow import fields, post_load, pre_dump, ValidationError +from marshmallow_enum import EnumField +from marshmallow_oneofschema import OneOfSchema + +from faculty.clients.base import BaseSchema +from faculty.clients.experiment.models import ( + ComparisonOperator, + DeleteExperimentRunsResponse, + Experiment, + ExperimentRun, + ExperimentRunStatus, + ListExperimentRunsResponse, + LogicalOperator, + Metric, + MetricDataPoint, + MetricHistory, + Page, + Pagination, + Param, + RestoreExperimentRunsResponse, + SortOrder, + Tag, +) + + +class PageSchema(BaseSchema): + start = fields.Integer(required=True) + limit = fields.Integer(required=True) + + @post_load + def make_page(self, data): + return Page(**data) + + +class MetricSchema(BaseSchema): + key = fields.String(required=True) + value = fields.Float(required=True) + timestamp = fields.DateTime(required=True) + step = fields.Integer(required=True) + + @post_load + def make_metric(self, data): + return Metric(**data) + + +class ParamSchema(BaseSchema): + key = fields.String(required=True) + value = fields.String(required=True) + + @post_load + def make_param(self, data): + return Param(**data) + + +class TagSchema(BaseSchema): + key = fields.String(required=True) + value = fields.String(required=True) + + @post_load + def make_tag(self, data): + return Tag(**data) + + +class ExperimentSchema(BaseSchema): + id = fields.Integer(data_key="experimentId", required=True) + name = fields.String(required=True) + description = fields.String(required=True) + artifact_location = fields.String( + data_key="artifactLocation", required=True + ) + created_at = fields.DateTime(data_key="createdAt", required=True) + last_updated_at = fields.DateTime(data_key="lastUpdatedAt", required=True) + deleted_at = fields.DateTime(data_key="deletedAt", missing=None) + + @post_load + def make_experiment(self, data): + return Experiment(**data) + + +class ExperimentRunSchema(BaseSchema): + id = fields.UUID(data_key="runId", 
required=True) + run_number = fields.Integer(data_key="runNumber", required=True) + experiment_id = fields.Integer(data_key="experimentId", required=True) + name = fields.String(required=True) + parent_run_id = fields.UUID(data_key="parentRunId", missing=None) + artifact_location = fields.String( + data_key="artifactLocation", required=True + ) + status = EnumField(ExperimentRunStatus, by_value=True, required=True) + started_at = fields.DateTime(data_key="startedAt", required=True) + ended_at = fields.DateTime(data_key="endedAt", missing=None) + deleted_at = fields.DateTime(data_key="deletedAt", missing=None) + tags = fields.Nested(TagSchema, many=True, required=True) + params = fields.Nested(ParamSchema, many=True, required=True) + metrics = fields.Nested(MetricSchema, many=True, required=True) + + @post_load + def make_experiment_run(self, data): + return ExperimentRun(**data) + + +class ExperimentRunDataSchema(BaseSchema): + metrics = fields.List(fields.Nested(MetricSchema)) + params = fields.List(fields.Nested(ParamSchema)) + tags = fields.List(fields.Nested(TagSchema)) + + +class ExperimentRunInfoSchema(BaseSchema): + status = EnumField(ExperimentRunStatus, by_value=True, required=True) + ended_at = fields.DateTime(data_key="endedAt", missing=None) + + +class PaginationSchema(BaseSchema): + start = fields.Integer(required=True) + size = fields.Integer(required=True) + previous = fields.Nested(PageSchema, missing=None) + next = fields.Nested(PageSchema, missing=None) + + @post_load + def make_pagination(self, data): + return Pagination(**data) + + +class ListExperimentRunsResponseSchema(BaseSchema): + pagination = fields.Nested(PaginationSchema, required=True) + runs = fields.Nested(ExperimentRunSchema, many=True, required=True) + + @post_load + def make_list_runs_response_schema(self, data): + return ListExperimentRunsResponse(**data) + + +class CreateRunSchema(BaseSchema): + name = fields.String() + parent_run_id = fields.UUID(data_key="parentRunId") + 
started_at = fields.DateTime(data_key="startedAt") + artifact_location = fields.String(data_key="artifactLocation") + tags = fields.Nested(TagSchema, many=True, required=True) + + +class DeleteExperimentRunsResponseSchema(BaseSchema): + deleted_run_ids = fields.List( + fields.UUID(), data_key="deletedRunIds", required=True + ) + conflicted_run_ids = fields.List( + fields.UUID(), data_key="conflictedRunIds", required=True + ) + + @post_load + def make_delete_runs_response(self, data): + return DeleteExperimentRunsResponse(**data) + + +class RestoreExperimentRunsResponseSchema(BaseSchema): + restored_run_ids = fields.List( + fields.UUID(), data_key="restoredRunIds", required=True + ) + conflicted_run_ids = fields.List( + fields.UUID(), data_key="conflictedRunIds", required=True + ) + + @post_load + def make_restore_runs_response(self, data): + return RestoreExperimentRunsResponse(**data) + + +class ParamFilterValueField(fields.Field): + """Field that passes through strings or numbers.""" + + default_error_messages = { + "unsupported_type": "Param values must be of type str, int or float." 
+ } + + def _serialize(self, value, attr, obj, **kwargs): + if isinstance(value, str): + field = fields.String() + elif isinstance(value, int) or isinstance(value, float): + field = fields.Number() + else: + self.fail("unsupported_type") + return field._serialize(value, attr, obj, **kwargs) + + +class OptionalField(fields.Field): + """Wrap another field, passing through Nones.""" + + def __init__(self, nested, *args, **kwargs): + self.nested = nested + super().__init__(*args, **kwargs) + + def _deserialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._deserialize(value, *args, **kwargs) + + def _serialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._serialize(value, *args, **kwargs) + + +class FilterValueField(fields.Field): + def __init__(self, other_field_type, *args, **kwargs): + self.other_field_type = other_field_type + super(FilterValueField, self).__init__(*args, **kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + if obj.operator == ComparisonOperator.DEFINED: + field_cls = fields.Boolean + else: + field_cls = self.other_field_type + return field_cls()._serialize(value, attr, obj, **kwargs) + + +def _validate_discrete(operator): + if operator not in { + ComparisonOperator.DEFINED, + ComparisonOperator.EQUAL_TO, + ComparisonOperator.NOT_EQUAL_TO, + }: + raise ValidationError({"operator": "Not a discrete operator."}) + + +class ProjectIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) + by = fields.Constant("projectId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class ExperimentIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Integer) + by = fields.Constant("experimentId", dump_only=True) + + @pre_dump + def check_operator(self, obj): 
+ _validate_discrete(obj.operator) + return obj + + +class RunIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) + by = fields.Constant("runId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class DeletedAtFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.DateTime) + by = fields.Constant("deletedAt", dump_only=True) + + +class TagFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.String) + by = fields.Constant("tag", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class ParamFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(ParamFilterValueField) + by = fields.Constant("param", dump_only=True) + + @pre_dump + def check_operator(self, obj): + if isinstance(obj.value, str): + _validate_discrete(obj.operator) + return obj + + +class MetricFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Float) + by = fields.Constant("metric", dump_only=True) + + +class CompoundFilterSchema(BaseSchema): + operator = EnumField(LogicalOperator, by_value=True) + conditions = fields.List(fields.Nested("FilterSchema")) + + +class OneOfSchemaWithoutType(OneOfSchema): + def dump(self, *args, **kwargs): + data = super(OneOfSchemaWithoutType, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} + + +class FilterSchema(OneOfSchemaWithoutType): + type_schemas = { + "ProjectIdFilter": ProjectIdFilterSchema, + "ExperimentIdFilter": ExperimentIdFilterSchema, + 
"RunIdFilter": RunIdFilterSchema, + "DeletedAtFilter": DeletedAtFilterSchema, + "TagFilter": TagFilterSchema, + "ParamFilter": ParamFilterSchema, + "MetricFilter": MetricFilterSchema, + "CompoundFilter": CompoundFilterSchema, + } + + +class StartedAtSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("startedAt", dump_only=True) + + +class RunNumberSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("runNumber", dump_only=True) + + +class DurationSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("duration", dump_only=True) + + +class TagSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("tag", dump_only=True) + + +class ParamSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("param", dump_only=True) + + +class MetricSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("metric", dump_only=True) + + +class SortSchema(OneOfSchemaWithoutType): + type_schemas = { + "StartedAtSort": StartedAtSortSchema, + "RunNumberSort": RunNumberSortSchema, + "DurationSort": DurationSortSchema, + "TagSort": TagSortSchema, + "ParamSort": ParamSortSchema, + "MetricSort": MetricSortSchema, + } + + +class RunQuerySchema(BaseSchema): + filter = OptionalField(fields.Nested(FilterSchema)) + sort = fields.List(fields.Nested(SortSchema)) + page = fields.Nested(PageSchema, missing=None) + + +class MetricDataPointSchema(BaseSchema): + """Deserialise a data point from the metric history endpoint. + + This schema is written with the expectation that it is not used alongside + the metric subsampling feature, which can result in null timestamp or step, + or a non-integer step. 
+ """ + + value = fields.Float(required=True) + timestamp = fields.DateTime(required=True) + step = fields.Integer(required=True) + + @post_load + def make_metric(self, data): + return MetricDataPoint(**data) + + +class MetricHistorySchema(BaseSchema): + original_size = fields.Integer(data_key="originalSize", required=True) + subsampled = fields.Boolean(required=True) + key = fields.String(required=True) + history = fields.Nested(MetricDataPointSchema, many=True, required=True) + + @post_load + def make_history(self, data): + return MetricHistory(**data) diff --git a/tests/clients/experiment/test_init.py b/tests/clients/experiment/test_init.py index d6db90b0..ff3e8d0f 100644 --- a/tests/clients/experiment/test_init.py +++ b/tests/clients/experiment/test_init.py @@ -13,672 +13,49 @@ # limitations under the License. -from datetime import datetime from uuid import uuid4 import pytest -from marshmallow import ValidationError -from pytz import UTC from faculty.clients.base import Conflict from faculty.clients.experiment import ( + ExperimentClient, + ExperimentDeleted, + ExperimentNameConflict, + ParamConflict, +) +from faculty.clients.experiment.models import ( ComparisonOperator, CompoundFilter, - CreateRunSchema, - DeleteExperimentRunsResponse, - DeleteExperimentRunsResponseSchema, DeletedAtFilter, - DurationSort, - Experiment, - ExperimentClient, - ExperimentDeleted, ExperimentIdFilter, - ExperimentNameConflict, - ExperimentRun, - ExperimentRunDataSchema, - ExperimentRunSchema, - ExperimentRunStatus, - ExperimentSchema, - FilterSchema, LifecycleStage, - ListExperimentRunsResponse, - ListExperimentRunsResponseSchema, LogicalOperator, Metric, - MetricDataPoint, - MetricFilter, - MetricHistory, - MetricHistorySchema, - MetricSchema, - MetricSort, Page, - PageSchema, - Pagination, - PaginationSchema, - Param, - ParamConflict, - ParamFilter, - ParamSchema, - ParamSort, - ProjectIdFilter, - RestoreExperimentRunsResponse, - RestoreExperimentRunsResponseSchema, RunIdFilter, 
- RunNumberSort, RunQuery, - RunQuerySchema, - SortOrder, - SortSchema, - StartedAtSort, - Tag, - TagFilter, - TagSchema, - TagSort, ) + PROJECT_ID = uuid4() -EXPERIMENT_ID = 661 +EXPERIMENT_ID = 234 EXPERIMENT_RUN_ID = uuid4() -EXPERIMENT_RUN_NUMBER = 3 -EXPERIMENT_RUN_NAME = "run name" PARENT_RUN_ID = uuid4() -CREATED_AT = datetime(2018, 3, 10, 11, 32, 6, 247000, tzinfo=UTC) -CREATED_AT_STRING = "2018-03-10T11:32:06.247Z" -LAST_UPDATED_AT = datetime(2018, 3, 10, 11, 32, 30, 172000, tzinfo=UTC) -LAST_UPDATED_AT_STRING = "2018-03-10T11:32:30.172Z" -DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC) -DELETED_AT_STRING = "2018-03-10T11:37:42.482Z" -DELETED_AT_STRING_PYTHON = "2018-03-10T11:37:42.482000+00:00" - -EXPERIMENT = Experiment( - id=EXPERIMENT_ID, - name="experiment name", - description="experiment description", - artifact_location="https://example.com", - created_at=CREATED_AT, - last_updated_at=LAST_UPDATED_AT, - deleted_at=DELETED_AT, -) -EXPERIMENT_BODY = { - "experimentId": EXPERIMENT_ID, - "name": EXPERIMENT.name, - "description": EXPERIMENT.description, - "artifactLocation": EXPERIMENT.artifact_location, - "createdAt": CREATED_AT_STRING, - "lastUpdatedAt": LAST_UPDATED_AT_STRING, - "deletedAt": DELETED_AT_STRING, -} - -RUN_ID = uuid4() -RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC) -RUN_STARTED_AT_NO_TIMEZONE = datetime(2018, 3, 10, 11, 39, 12, 110000) -RUN_STARTED_AT_STRING_PYTHON = "2018-03-10T11:39:12.110000+00:00" -RUN_STARTED_AT_STRING_JAVA = "2018-03-10T11:39:12.11Z" -RUN_ENDED_AT = datetime(2018, 3, 10, 11, 39, 15, 110000, tzinfo=UTC) -RUN_ENDED_AT_STRING = "2018-03-10T11:39:15.11Z" - -TAG = Tag(key="tag-key", value="tag-value") -TAG_BODY = {"key": "tag-key", "value": "tag-value"} - -OTHER_TAG = Tag(key="other-tag-key", value="other-tag-value") -OTHER_TAG_BODY = {"key": "other-tag-key", "value": "other-tag-value"} - -PARAM = Param(key="param-key", value="param-value") -PARAM_BODY = {"key": "param-key", 
"value": "param-value"} - -METRIC_KEY = "metric-key" -METRIC = Metric( - key=METRIC_KEY, - value=123.0, - timestamp=datetime(2018, 3, 12, 16, 20, 22, 122000, tzinfo=UTC), - step=0, -) -METRIC_BODY = { - "key": METRIC.key, - "value": METRIC.value, - "timestamp": "2018-03-12T16:20:22.122000+00:00", - "step": METRIC.step, -} - -METRIC_DATA_POINT = MetricDataPoint( - value=METRIC.value, timestamp=METRIC.timestamp, step=METRIC.step -) -METRIC_DATA_POINT_BODY = { - "value": METRIC_BODY["value"], - "timestamp": METRIC_BODY["timestamp"], - "step": METRIC_BODY["step"], -} - -METRIC_HISTORY = MetricHistory( - original_size=1, - subsampled=False, - key=METRIC_KEY, - history=[METRIC_DATA_POINT], -) -METRIC_HISTORY_BODY = { - "originalSize": METRIC_HISTORY.original_size, - "subsampled": METRIC_HISTORY.subsampled, - "key": METRIC_HISTORY.key, - "history": [METRIC_DATA_POINT_BODY], -} - -EXPERIMENT_RUN = ExperimentRun( - id=EXPERIMENT_RUN_ID, - run_number=EXPERIMENT_RUN_NUMBER, - name=EXPERIMENT_RUN_NAME, - parent_run_id=PARENT_RUN_ID, - experiment_id=EXPERIMENT.id, - artifact_location="faculty:", - status=ExperimentRunStatus.RUNNING, - started_at=RUN_STARTED_AT, - ended_at=RUN_ENDED_AT, - deleted_at=DELETED_AT, - tags=[TAG], - params=[PARAM], - metrics=[METRIC], -) -EXPERIMENT_RUN_BODY = { - "experimentId": EXPERIMENT.id, - "runId": str(EXPERIMENT_RUN_ID), - "runNumber": EXPERIMENT_RUN_NUMBER, - "name": EXPERIMENT_RUN_NAME, - "parentRunId": str(PARENT_RUN_ID), - "artifactLocation": "faculty:", - "status": "running", - "startedAt": RUN_STARTED_AT_STRING_JAVA, - "endedAt": RUN_ENDED_AT_STRING, - "deletedAt": DELETED_AT_STRING, - "tags": [TAG_BODY], - "metrics": [METRIC_BODY], - "params": [PARAM_BODY], -} - -EXPERIMENT_RUN_DATA_BODY = { - "metrics": [METRIC_BODY], - "params": [PARAM_BODY], - "tags": [TAG_BODY], -} - - -PAGE = Page(start=3, limit=10) -PAGE_BODY = {"start": PAGE.start, "limit": PAGE.limit} - -PAGINATION = Pagination( - start=20, - size=10, - previous=Page(start=10, 
limit=10), - next=Page(start=30, limit=10), -) -PAGINATION_BODY = { - "start": PAGINATION.start, - "size": PAGINATION.size, - "previous": { - "start": PAGINATION.previous.start, - "limit": PAGINATION.previous.limit, - }, - "next": {"start": PAGINATION.next.start, "limit": PAGINATION.next.limit}, -} - -LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse( - runs=[EXPERIMENT_RUN], pagination=PAGINATION -) -LIST_EXPERIMENT_RUNS_RESPONSE_BODY = { - "runs": [EXPERIMENT_RUN_BODY], - "pagination": PAGINATION_BODY, -} - -DELETE_EXPERIMENT_RUNS_RESPONSE = DeleteExperimentRunsResponse( - deleted_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] -) -DELETE_EXPERIMENT_RUNS_RESPONSE_BODY = { - "deletedRunIds": [ - str(run_id) - for run_id in DELETE_EXPERIMENT_RUNS_RESPONSE.deleted_run_ids - ], - "conflictedRunIds": [ - str(run_id) - for run_id in DELETE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids - ], -} - -RESTORE_EXPERIMENT_RUNS_RESPONSE = RestoreExperimentRunsResponse( - restored_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] -) -RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY = { - "restoredRunIds": [ - str(run_id) - for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.restored_run_ids - ], - "conflictedRunIds": [ - str(run_id) - for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids - ], -} - - -def test_experiment_schema(): - data = ExperimentSchema().load(EXPERIMENT_BODY) - assert data == EXPERIMENT - - -def test_experiment_schema_nullable_deleted_at(): - body = EXPERIMENT_BODY.copy() - body["deletedAt"] = None - data = ExperimentSchema().load(body) - assert data.deleted_at is None - - -def test_experiment_schema_invalid(): - with pytest.raises(ValidationError): - ExperimentSchema().load({}) - - -def test_experiment_run_schema(): - data = ExperimentRunSchema().load(EXPERIMENT_RUN_BODY) - assert data == EXPERIMENT_RUN - - -@pytest.mark.parametrize( - "data_key, field", - [ - ("parentRunId", "parent_run_id"), - ("endedAt", "ended_at"), - 
("deletedAt", "deleted_at"), - ], -) -def test_experiment_run_schema_nullable_field(data_key, field): - body = EXPERIMENT_RUN_BODY.copy() - del body[data_key] - data = ExperimentRunSchema().load(body) - assert getattr(data, field) is None - - -@pytest.mark.parametrize("parent_run_id", [None, PARENT_RUN_ID]) -@pytest.mark.parametrize( - "started_at", - [RUN_STARTED_AT, RUN_STARTED_AT_NO_TIMEZONE], - ids=["timezone", "no timezone"], -) -@pytest.mark.parametrize("artifact_location", [None, "faculty:project-id"]) -@pytest.mark.parametrize("tags", [[], [{"key": "key", "value": "value"}]]) -def test_create_run_schema(parent_run_id, started_at, artifact_location, tags): - data = CreateRunSchema().dump( - { - "name": EXPERIMENT_RUN_NAME, - "parent_run_id": parent_run_id, - "started_at": started_at, - "artifact_location": artifact_location, - "tags": tags, - } - ) - assert data == { - "name": EXPERIMENT_RUN_NAME, - "parentRunId": None if parent_run_id is None else str(parent_run_id), - "startedAt": RUN_STARTED_AT_STRING_PYTHON, - "artifactLocation": artifact_location, - "tags": tags, - } - - -def test_metric_schema(): - data = MetricSchema().load(METRIC_BODY) - assert data == METRIC - - -def test_param_schema(): - data = ParamSchema().load(PARAM_BODY) - assert data == PARAM - - -def test_tag_schema(): - data = TagSchema().load(TAG_BODY) - assert data == TAG - - -def test_tag_schema_dump(): - data = TagSchema().dump(TAG_BODY) - assert data == TAG_BODY - - -def test_experiment_run_data_schema(): - data = ExperimentRunDataSchema().dump( - {"metrics": [METRIC], "params": [PARAM], "tags": [TAG]} - ) - assert data == EXPERIMENT_RUN_DATA_BODY - - -def test_experiment_run_data_schema_empty(): - data = ExperimentRunDataSchema().dump({}) - assert data == {} - - -def test_experiment_run_data_schema_multiple(): - data = ExperimentRunDataSchema().dump({"tags": [TAG, OTHER_TAG]}) - assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} - - -PROJECT_ID_FILTER = 
ProjectIdFilter(ComparisonOperator.EQUAL_TO, PROJECT_ID) -PROJECT_ID_FILTER_BODY = { - "by": "projectId", - "operator": "eq", - "value": str(PROJECT_ID), -} - -TAG_FILTER = TagFilter("tag-key", ComparisonOperator.EQUAL_TO, "tag-value") -TAG_FILTER_BODY = { - "by": "tag", - "key": "tag-key", - "operator": "eq", - "value": "tag-value", -} - - -DEFINED_TEST_CASES = [ - (ComparisonOperator.DEFINED, False, "defined", False), - (ComparisonOperator.DEFINED, True, "defined", True), - (ComparisonOperator.DEFINED, 0, "defined", False), - (ComparisonOperator.DEFINED, 1, "defined", True), -] - - -def discrete_test_cases(value, expected): - return DEFINED_TEST_CASES + [ - (ComparisonOperator.EQUAL_TO, value, "eq", expected), - (ComparisonOperator.NOT_EQUAL_TO, value, "ne", expected), - ] - - -def continuous_test_cases(value, expected): - return discrete_test_cases(value, expected) + [ - (ComparisonOperator.GREATER_THAN, value, "gt", expected), - (ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value, "ge", expected), - (ComparisonOperator.LESS_THAN, value, "lt", expected), - (ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value, "le", expected), - ] - - -@pytest.mark.parametrize( - "operator, value, expected_operator, expected_value", - discrete_test_cases(PROJECT_ID, str(PROJECT_ID)), -) -def test_filter_schema_project_id( - operator, value, expected_operator, expected_value -): - filter = ProjectIdFilter(operator, value) - data = FilterSchema().dump(filter) - assert data == { - "by": "projectId", - "operator": expected_operator, - "value": expected_value, - } - - -@pytest.mark.parametrize( - "operator, value, expected_operator, expected_value", - discrete_test_cases(EXPERIMENT_ID, EXPERIMENT_ID), -) -def test_filter_schema_experiment_id( - operator, value, expected_operator, expected_value -): - filter = ExperimentIdFilter(operator, value) - data = FilterSchema().dump(filter) - assert data == { - "by": "experimentId", - "operator": expected_operator, - "value": expected_value, - } - 
- -@pytest.mark.parametrize( - "operator, value, expected_operator, expected_value", - discrete_test_cases(RUN_ID, str(RUN_ID)), -) -def test_filter_schema_run_id( - operator, value, expected_operator, expected_value -): - filter = RunIdFilter(operator, value) - data = FilterSchema().dump(filter) - assert data == { - "by": "runId", - "operator": expected_operator, - "value": expected_value, - } - - -@pytest.mark.parametrize( - "operator, value, expected_operator, expected_value", - continuous_test_cases(DELETED_AT, DELETED_AT_STRING_PYTHON), -) -def test_filter_schema_deleted_at( - operator, value, expected_operator, expected_value -): - filter = DeletedAtFilter(operator, value) - data = FilterSchema().dump(filter) - assert data == { - "by": "deletedAt", - "operator": expected_operator, - "value": expected_value, - } - - -@pytest.mark.parametrize( - "operator, value, expected_operator, expected_value", - discrete_test_cases("tag-value", "tag-value"), -) -def test_filter_schema_tag(operator, value, expected_operator, expected_value): - filter = TagFilter("tag-key", operator, value) - data = FilterSchema().dump(filter) - assert data == { - "by": "tag", - "key": "tag-key", - "operator": expected_operator, - "value": expected_value, - } - - -@pytest.mark.parametrize( - "operator, value, expected_operator, expected_value", - discrete_test_cases("param-value", "param-value") - + continuous_test_cases(123.2, 123.2), -) -def test_filter_schema_param( - operator, value, expected_operator, expected_value -): - filter = ParamFilter("param-key", operator, value) - data = FilterSchema().dump(filter) - assert data == { - "by": "param", - "key": "param-key", - "operator": expected_operator, - "value": expected_value, - } - - -@pytest.mark.parametrize( - "operator, value, expected_operator, expected_value", - continuous_test_cases(45.6, 45.6), -) -def test_filter_schema_metric( - operator, value, expected_operator, expected_value -): - filter = MetricFilter("metric-key", operator, 
value) - data = FilterSchema().dump(filter) - assert data == { - "by": "metric", - "key": "metric-key", - "operator": expected_operator, - "value": expected_value, - } - - -@pytest.mark.parametrize( - "filter_type", - [ProjectIdFilter, ExperimentIdFilter, RunIdFilter, DeletedAtFilter], -) -def test_filter_schema_invalid_value_no_key(filter_type): - filter = filter_type(ComparisonOperator.EQUAL_TO, "invalid") - with pytest.raises(ValidationError): - FilterSchema().dump(filter) - - -@pytest.mark.parametrize( - "filter_type, value", [(ParamFilter, None), (MetricFilter, "invalid")] -) -def test_filter_schema_invalid_value_with_key(filter_type, value): - filter = filter_type("key", ComparisonOperator.EQUAL_TO, value) - with pytest.raises(ValidationError): - FilterSchema().dump(filter) - - -@pytest.mark.parametrize( - "filter_type, value", - [ - (ProjectIdFilter, PROJECT_ID), - (ExperimentIdFilter, EXPERIMENT_ID), - (RunIdFilter, RUN_ID), - ], -) -@pytest.mark.parametrize( - "operator", - [ - ComparisonOperator.GREATER_THAN, - ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, - ComparisonOperator.LESS_THAN, - ComparisonOperator.LESS_THAN_OR_EQUAL_TO, - ], -) -def test_filter_schema_invalid_operator_no_key(filter_type, value, operator): - filter = filter_type(operator, value) - with pytest.raises(ValidationError, match="Not a discrete operator"): - FilterSchema().dump(filter) - - -@pytest.mark.parametrize( - "filter_type, value", - [(TagFilter, "tag-value"), (ParamFilter, "param-string-value")], -) -@pytest.mark.parametrize( - "operator", - [ - ComparisonOperator.GREATER_THAN, - ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, - ComparisonOperator.LESS_THAN, - ComparisonOperator.LESS_THAN_OR_EQUAL_TO, - ], -) -def test_filter_schema_invalid_operator_with_key(filter_type, value, operator): - filter = filter_type("key", operator, value) - with pytest.raises(ValidationError, match="Not a discrete operator"): - FilterSchema().dump(filter) - - -@pytest.mark.parametrize( - "operator, 
expected_operator", - [(LogicalOperator.AND, "and"), (LogicalOperator.OR, "or")], -) -def test_filter_schema_compound(operator, expected_operator): - filter = CompoundFilter(operator, [PROJECT_ID_FILTER, TAG_FILTER]) - data = FilterSchema().dump(filter) - assert data == { - "operator": expected_operator, - "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], - } - - -def test_filter_schema_nested(): - filter = CompoundFilter( - LogicalOperator.AND, - [ - CompoundFilter( - LogicalOperator.AND, [PROJECT_ID_FILTER, TAG_FILTER] - ), - CompoundFilter( - LogicalOperator.OR, [TAG_FILTER, PROJECT_ID_FILTER] - ), - ], - ) - data = FilterSchema().dump(filter) - assert data == { - "operator": "and", - "conditions": [ - { - "operator": "and", - "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], - }, - { - "operator": "or", - "conditions": [TAG_FILTER_BODY, PROJECT_ID_FILTER_BODY], - }, - ], - } - - -@pytest.mark.parametrize( - "sort_type, by", - [ - (StartedAtSort, "startedAt"), - (RunNumberSort, "runNumber"), - (DurationSort, "duration"), - ], -) -@pytest.mark.parametrize( - "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] -) -def test_sort_schema_no_tag(sort_type, by, order, expected_order): - sort = sort_type(order) - data = SortSchema().dump(sort) - assert data == {"by": by, "order": expected_order} - - -@pytest.mark.parametrize( - "sort_type, by", - [(TagSort, "tag"), (ParamSort, "param"), (MetricSort, "metric")], -) -@pytest.mark.parametrize( - "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] -) -def test_sort_schema_with_tag(sort_type, by, order, expected_order): - sort = sort_type("sort-key", order) - data = SortSchema().dump(sort) - assert data == {"by": by, "key": "sort-key", "order": expected_order} - - -def test_run_query_schema(mocker): - mocker.patch.object(FilterSchema, "dump") - mocker.patch.object(SortSchema, "dump") - mocker.patch.object(PageSchema, "dump") - - filter = mocker.Mock() - sorts = 
[mocker.Mock(), mocker.Mock()] - page = mocker.Mock() - - run_query = RunQuery(filter, sorts, page) - data = RunQuerySchema().dump(run_query) - - assert data == { - "filter": FilterSchema.dump.return_value, - "sort": [SortSchema.dump.return_value, SortSchema.dump.return_value], - "page": PageSchema.dump.return_value, - } - - -def test_run_query_schema_defaults(): - run_query = RunQuery(None, None, None) - data = RunQuerySchema().dump(run_query) - assert data == {"filter": None, "sort": None, "page": None} @pytest.mark.parametrize("description", [None, "experiment description"]) @pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) def test_experiment_client_create(mocker, description, artifact_location): - mocker.patch.object(ExperimentClient, "_post", return_value=EXPERIMENT) + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_post", return_value=experiment) schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") client = ExperimentClient(mocker.Mock()) returned_experiment = client.create( PROJECT_ID, "experiment name", description, artifact_location ) - assert returned_experiment == EXPERIMENT + assert returned_experiment == experiment schema_mock.assert_called_once_with() ExperimentClient._post.assert_called_once_with( @@ -705,26 +82,28 @@ def test_experiment_client_create_name_conflict(mocker): def test_experiment_client_get(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=EXPERIMENT) + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", return_value=experiment) schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") client = ExperimentClient(mocker.Mock()) - returned_experiment = client.get(PROJECT_ID, EXPERIMENT.id) - assert returned_experiment == EXPERIMENT + returned_experiment = client.get(PROJECT_ID, EXPERIMENT_ID) + assert returned_experiment == experiment schema_mock.assert_called_once_with() ExperimentClient._get.assert_called_once_with( 
- "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT.id), + "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT_ID), schema_mock.return_value, ) def test_experiment_client_list(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=[EXPERIMENT]) + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") client = ExperimentClient(mocker.Mock()) - assert client.list(PROJECT_ID) == [EXPERIMENT] + assert client.list(PROJECT_ID) == [experiment] schema_mock.assert_called_once_with(many=True) ExperimentClient._get.assert_called_once_with( @@ -735,14 +114,15 @@ def test_experiment_client_list(mocker): def test_experiment_client_list_lifecycle_filter(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=[EXPERIMENT]) + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") client = ExperimentClient(mocker.Mock()) returned_experiments = client.list( PROJECT_ID, lifecycle_stage=LifecycleStage.ACTIVE ) - assert returned_experiments == [EXPERIMENT] + assert returned_experiments == [experiment] schema_mock.assert_called_once_with(many=True) ExperimentClient._get.assert_called_once_with( @@ -803,7 +183,8 @@ def test_restore(mocker): def test_experiment_create_run(mocker): - mocker.patch.object(ExperimentClient, "_post", return_value=EXPERIMENT_RUN) + run = mocker.Mock() + mocker.patch.object(ExperimentClient, "_post", return_value=run) request_schema_mock = mocker.patch( "faculty.clients.experiment.CreateRunSchema" ) @@ -811,6 +192,7 @@ def test_experiment_create_run(mocker): response_schema_mock = mocker.patch( "faculty.clients.experiment.ExperimentRunSchema" ) + run_name = mocker.Mock() started_at = mocker.Mock() artifact_location = mocker.Mock() @@ -818,17 +200,17 @@ def 
test_experiment_create_run(mocker): returned_run = client.create_run( PROJECT_ID, EXPERIMENT_ID, - EXPERIMENT_RUN_NAME, + run_name, started_at, PARENT_RUN_ID, artifact_location=artifact_location, ) - assert returned_run == EXPERIMENT_RUN + assert returned_run == run request_schema_mock.assert_called_once_with() dump_mock.assert_called_once_with( { - "name": EXPERIMENT_RUN_NAME, + "name": run_name, "parent_run_id": PARENT_RUN_ID, "started_at": started_at, "artifact_location": artifact_location, @@ -851,30 +233,29 @@ def test_experiment_create_run_experiment_deleted_conflict(mocker): exception = Conflict(response_mock, message, error_code) mocker.patch.object(ExperimentClient, "_post", side_effect=exception) - started_at = mocker.Mock() - artifact_location = mocker.Mock() client = ExperimentClient(mocker.Mock()) with pytest.raises(ExperimentDeleted, match=message): client.create_run( PROJECT_ID, EXPERIMENT_ID, - EXPERIMENT_RUN_NAME, - started_at, - PARENT_RUN_ID, - artifact_location=artifact_location, + name=mocker.Mock(), + started_at=mocker.Mock(), + parent_run_id=PARENT_RUN_ID, + artifact_location=mocker.Mock(), ) def test_experiment_client_get_run(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=EXPERIMENT_RUN) + run = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", return_value=run) schema_mock = mocker.patch( "faculty.clients.experiment.ExperimentRunSchema" ) client = ExperimentClient(mocker.Mock()) returned_run = client.get_run(PROJECT_ID, EXPERIMENT_RUN_ID) - assert returned_run == EXPERIMENT_RUN + assert returned_run == run schema_mock.assert_called_once_with() ExperimentClient._get.assert_called_once_with( @@ -883,60 +264,6 @@ def test_experiment_client_get_run(mocker): ) -def test_list_runs_schema(mocker): - data = ListExperimentRunsResponseSchema().load( - LIST_EXPERIMENT_RUNS_RESPONSE_BODY - ) - assert data == LIST_EXPERIMENT_RUNS_RESPONSE - - -def test_page_schema_load(): - data = PageSchema().load(PAGE_BODY) - assert 
data == PAGE - - -def test_page_schema_dump(): - data = PageSchema().dump(PAGE) - assert data == PAGE_BODY - - -def test_pagination_schema(): - data = PaginationSchema().load(PAGINATION_BODY) - assert data == PAGINATION - - -@pytest.mark.parametrize("field", ["previous", "next"]) -def test_pagination_schema_nullable_field(field): - body = PAGINATION_BODY.copy() - del body[field] - data = PaginationSchema().load(body) - assert getattr(data, field) is None - - -def test_delete_experiment_runs_response_schema(mocker): - data = DeleteExperimentRunsResponseSchema().load( - DELETE_EXPERIMENT_RUNS_RESPONSE_BODY - ) - assert data == DELETE_EXPERIMENT_RUNS_RESPONSE - - -def test_delete_experiment_runs_response_schema_invalid(mocker): - with pytest.raises(ValidationError): - DeleteExperimentRunsResponseSchema().load({}) - - -def test_restore_experiment_runs_response_schema(mocker): - data = RestoreExperimentRunsResponseSchema().load( - RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY - ) - assert data == RESTORE_EXPERIMENT_RUNS_RESPONSE - - -def test_restore_experiment_runs_response_schema_invalid(mocker): - with pytest.raises(ValidationError): - RestoreExperimentRunsResponseSchema().load({}) - - def test_experiment_client_list_runs(mocker): mocker.patch.object(ExperimentClient, "query_runs") @@ -981,13 +308,15 @@ def test_experiment_client_list_runs_defaults(mocker): def test_experiment_client_query_runs(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) + list_response = mocker.Mock() + mocker.patch.object(ExperimentClient, "_post", return_value=list_response) response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) - request_dump_mock = mocker.patch.object(RunQuerySchema, "dump") + request_schema_mock = mocker.patch( + "faculty.clients.experiment.RunQuerySchema" + ) + request_dump_mock = request_schema_mock.return_value.dump filter = mocker.Mock() sort = mocker.Mock() @@ -997,7 
+326,7 @@ def test_experiment_client_query_runs(mocker): PROJECT_ID, filter, sort, start=20, limit=10 ) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE + assert list_result == list_response request_dump_mock.assert_called_once_with( RunQuery(filter, sort, Page(20, 10)) @@ -1017,18 +346,22 @@ def test_log_run_data(mocker): ) run_data_dump_mock = run_data_schema_mock.return_value.dump + metric = mocker.Mock() + param = mocker.Mock() + tag = mocker.Mock() + client = ExperimentClient(mocker.Mock()) client.log_run_data( PROJECT_ID, EXPERIMENT_RUN_ID, - metrics=[METRIC], - params=[PARAM], - tags=[TAG], + metrics=[metric], + params=[param], + tags=[tag], ) run_data_schema_mock.assert_called_once_with() run_data_dump_mock.assert_called_once_with( - {"metrics": [METRIC], "params": [PARAM], "tags": [TAG]} + {"metrics": [metric], "params": [param], "tags": [tag]} ) ExperimentClient._patch_raw.assert_called_once_with( "/project/{}/run/{}/data".format(PROJECT_ID, EXPERIMENT_RUN_ID), @@ -1048,7 +381,9 @@ def test_log_run_data_param_conflict(mocker): client = ExperimentClient(mocker.Mock()) with pytest.raises(ParamConflict, match=message): - client.log_run_data(PROJECT_ID, EXPERIMENT_RUN_ID, params=[PARAM]) + client.log_run_data( + PROJECT_ID, EXPERIMENT_RUN_ID, params=[mocker.Mock()] + ) def test_log_run_data_other_conflict(mocker): @@ -1059,7 +394,9 @@ def test_log_run_data_other_conflict(mocker): client = ExperimentClient(mocker.Mock()) with pytest.raises(Conflict): - client.log_run_data(PROJECT_ID, EXPERIMENT_RUN_ID, params=[PARAM]) + client.log_run_data( + PROJECT_ID, EXPERIMENT_RUN_ID, params=[mocker.Mock()] + ) def test_log_run_data_empty(mocker): @@ -1072,9 +409,8 @@ def test_log_run_data_empty(mocker): def test_update_run_info(mocker): - mocker.patch.object( - ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) + run = mocker.Mock() + mocker.patch.object(ExperimentClient, "_patch", return_value=run) run_schema_mock = mocker.patch( 
"faculty.clients.experiment.ExperimentRunSchema" ) @@ -1083,49 +419,19 @@ def test_update_run_info(mocker): ) run_info_dump_mock = run_info_schema_mock.return_value.dump - client = ExperimentClient(mocker.Mock()) - returned_run = client.update_run_info( - PROJECT_ID, - EXPERIMENT_RUN_ID, - EXPERIMENT_RUN.status, - EXPERIMENT_RUN.ended_at, - ) - assert returned_run == EXPERIMENT_RUN - - run_schema_mock.assert_called_once_with() - run_info_schema_mock.assert_called_once_with() - run_info_dump_mock.assert_called_once_with( - {"status": EXPERIMENT_RUN.status, "ended_at": EXPERIMENT_RUN.ended_at} - ) - ExperimentClient._patch.assert_called_once_with( - "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), - run_schema_mock.return_value, - json=run_info_dump_mock.return_value, - ) - - -def test_update_run_info_status_only(mocker): - mocker.patch.object( - ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) - run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" - ) - run_info_dump_mock = run_info_schema_mock.return_value.dump + status = mocker.Mock() + ended_at = mocker.Mock() client = ExperimentClient(mocker.Mock()) returned_run = client.update_run_info( - PROJECT_ID, EXPERIMENT_RUN_ID, status=EXPERIMENT_RUN.status + PROJECT_ID, EXPERIMENT_RUN_ID, status, ended_at ) - assert returned_run == EXPERIMENT_RUN + assert returned_run == run run_schema_mock.assert_called_once_with() run_info_schema_mock.assert_called_once_with() run_info_dump_mock.assert_called_once_with( - {"status": EXPERIMENT_RUN.status, "ended_at": None} + {"status": status, "ended_at": ended_at} ) ExperimentClient._patch.assert_called_once_with( "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), @@ -1134,111 +440,67 @@ def test_update_run_info_status_only(mocker): ) -def test_update_run_info_ended_at_only(mocker): - mocker.patch.object( - 
ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) - run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" - ) - run_info_dump_mock = run_info_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - returned_run = client.update_run_info( - PROJECT_ID, EXPERIMENT_RUN_ID, ended_at=EXPERIMENT_RUN.ended_at - ) - assert returned_run == EXPERIMENT_RUN - - run_schema_mock.assert_called_once_with() - run_info_schema_mock.assert_called_once_with() - run_info_dump_mock.assert_called_once_with( - {"status": None, "ended_at": EXPERIMENT_RUN.ended_at} - ) - ExperimentClient._patch.assert_called_once_with( - "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), - run_schema_mock.return_value, - json=run_info_dump_mock.return_value, - ) - - -def test_update_run_info_empty(mocker): - mocker.patch.object( - ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) - run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" - ) - run_info_dump_mock = run_info_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - returned_run = client.update_run_info(PROJECT_ID, EXPERIMENT_RUN_ID) - assert returned_run == EXPERIMENT_RUN - - run_schema_mock.assert_called_once_with() - run_info_schema_mock.assert_called_once_with() - run_info_dump_mock.assert_called_once_with( - {"status": None, "ended_at": None} - ) - ExperimentClient._patch.assert_called_once_with( - "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), - run_schema_mock.return_value, - json=run_info_dump_mock.return_value, - ) - - -def test_metric_history_schema(): - data = MetricHistorySchema().load(METRIC_HISTORY_BODY) - assert data == METRIC_HISTORY - - -def test_metric_history_schema_invalid(): - with 
pytest.raises(ValidationError): - MetricHistorySchema().load({}) - - def test_get_metric_history(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=METRIC_HISTORY) + key = mocker.Mock() + data_point_0 = mocker.Mock() + data_point_1 = mocker.Mock() + metric_history = mocker.Mock(key=key, history=[data_point_0, data_point_1]) + + mocker.patch.object(ExperimentClient, "_get", return_value=metric_history) metric_history_schema_mock = mocker.patch( "faculty.clients.experiment.MetricHistorySchema" ) client = ExperimentClient(mocker.Mock()) - - returned_metric_history = client.get_metric_history( - PROJECT_ID, EXPERIMENT_RUN_ID, METRIC_KEY + metrics = client.get_metric_history( + PROJECT_ID, EXPERIMENT_RUN_ID, "metric-key" ) - assert returned_metric_history == [METRIC] - metric_history_schema_mock.assert_called_once_with() + expected = [ + Metric( + key=key, + step=data_point_0.step, + timestamp=data_point_0.timestamp, + value=data_point_0.value, + ), + Metric( + key=key, + step=data_point_1.step, + timestamp=data_point_1.timestamp, + value=data_point_1.value, + ), + ] + assert metrics == expected + metric_history_schema_mock.assert_called_once_with() ExperimentClient._get.assert_called_once_with( - "/project/{}/run/{}/metric/{}/history".format( - PROJECT_ID, EXPERIMENT_RUN_ID, METRIC_KEY + "/project/{}/run/{}/metric/metric-key/history".format( + PROJECT_ID, EXPERIMENT_RUN_ID ), metric_history_schema_mock.return_value, ) def test_delete_runs(mocker): + delete_runs_response = mocker.Mock() mocker.patch.object( - ExperimentClient, "_post", return_value=DELETE_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=delete_runs_response ) response_schema_mock = mocker.patch( "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" ) - filter_dump_mock = mocker.patch.object(FilterSchema, "dump") + filter_schema_mock = mocker.patch( + "faculty.clients.experiment.FilterSchema" + ) + filter_dump_mock = filter_schema_mock.return_value.dump + 
run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) response = client.delete_runs(PROJECT_ID, run_ids) - assert response == DELETE_EXPERIMENT_RUNS_RESPONSE + assert response == delete_runs_response + expected_filter = CompoundFilter( LogicalOperator.OR, [ @@ -1276,28 +538,31 @@ def test_delete_runs_empty_list(mocker): client = ExperimentClient(mocker.Mock()) response = client.delete_runs(PROJECT_ID, run_ids=[]) - assert response == DeleteExperimentRunsResponse( - deleted_run_ids=[], conflicted_run_ids=[] - ) ExperimentClient._post.assert_not_called() + assert len(response.deleted_run_ids) == 0 + assert len(response.conflicted_run_ids) == 0 def test_restore_runs(mocker): + restore_runs_response = mocker.Mock() mocker.patch.object( - ExperimentClient, - "_post", - return_value=RESTORE_EXPERIMENT_RUNS_RESPONSE, + ExperimentClient, "_post", return_value=restore_runs_response ) response_schema_mock = mocker.patch( "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" ) - filter_dump_mock = mocker.patch.object(FilterSchema, "dump") + filter_schema_mock = mocker.patch( + "faculty.clients.experiment.FilterSchema" + ) + filter_dump_mock = filter_schema_mock.return_value.dump + run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) response = client.restore_runs(PROJECT_ID, run_ids) - assert response == RESTORE_EXPERIMENT_RUNS_RESPONSE + assert response == restore_runs_response + expected_filter = CompoundFilter( LogicalOperator.OR, [ @@ -1335,7 +600,6 @@ def test_restore_runs_empty_list(mocker): client = ExperimentClient(mocker.Mock()) response = client.restore_runs(PROJECT_ID, run_ids=[]) - assert response == RestoreExperimentRunsResponse( - restored_run_ids=[], conflicted_run_ids=[] - ) ExperimentClient._post.assert_not_called() + assert len(response.restored_run_ids) == 0 + assert len(response.conflicted_run_ids) == 0 diff --git a/tests/clients/experiment/test_schemas.py b/tests/clients/experiment/test_schemas.py new file mode 
100644 index 00000000..e5e3066b --- /dev/null +++ b/tests/clients/experiment/test_schemas.py @@ -0,0 +1,728 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import datetime +from uuid import uuid4 + +import pytest +from marshmallow import ValidationError +from pytz import UTC + +from faculty.clients.experiment.models import ( + ComparisonOperator, + CompoundFilter, + DeleteExperimentRunsResponse, + DeletedAtFilter, + DurationSort, + Experiment, + ExperimentIdFilter, + ExperimentRun, + ExperimentRunStatus, + ListExperimentRunsResponse, + LogicalOperator, + Metric, + MetricDataPoint, + MetricFilter, + MetricHistory, + MetricSort, + Page, + Pagination, + Param, + ParamFilter, + ParamSort, + ProjectIdFilter, + RestoreExperimentRunsResponse, + RunIdFilter, + RunNumberSort, + RunQuery, + SortOrder, + StartedAtSort, + Tag, + TagFilter, + TagSort, +) +from faculty.clients.experiment.schemas import ( + CreateRunSchema, + DeleteExperimentRunsResponseSchema, + ExperimentRunDataSchema, + ExperimentRunSchema, + ExperimentSchema, + FilterSchema, + ListExperimentRunsResponseSchema, + MetricHistorySchema, + MetricSchema, + PageSchema, + PaginationSchema, + ParamSchema, + RestoreExperimentRunsResponseSchema, + RunQuerySchema, + SortSchema, + TagSchema, +) + +PROJECT_ID = uuid4() +EXPERIMENT_ID = 661 +EXPERIMENT_RUN_ID = uuid4() +EXPERIMENT_RUN_NUMBER = 3 +EXPERIMENT_RUN_NAME = "run name" +PARENT_RUN_ID = uuid4() 
+CREATED_AT = datetime(2018, 3, 10, 11, 32, 6, 247000, tzinfo=UTC) +CREATED_AT_STRING = "2018-03-10T11:32:06.247Z" +LAST_UPDATED_AT = datetime(2018, 3, 10, 11, 32, 30, 172000, tzinfo=UTC) +LAST_UPDATED_AT_STRING = "2018-03-10T11:32:30.172Z" +DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC) +DELETED_AT_STRING = "2018-03-10T11:37:42.482Z" +DELETED_AT_STRING_PYTHON = "2018-03-10T11:37:42.482000+00:00" + +EXPERIMENT = Experiment( + id=EXPERIMENT_ID, + name="experiment name", + description="experiment description", + artifact_location="https://example.com", + created_at=CREATED_AT, + last_updated_at=LAST_UPDATED_AT, + deleted_at=DELETED_AT, +) +EXPERIMENT_BODY = { + "experimentId": EXPERIMENT_ID, + "name": EXPERIMENT.name, + "description": EXPERIMENT.description, + "artifactLocation": EXPERIMENT.artifact_location, + "createdAt": CREATED_AT_STRING, + "lastUpdatedAt": LAST_UPDATED_AT_STRING, + "deletedAt": DELETED_AT_STRING, +} + +RUN_ID = uuid4() +RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC) +RUN_STARTED_AT_NO_TIMEZONE = datetime(2018, 3, 10, 11, 39, 12, 110000) +RUN_STARTED_AT_STRING_PYTHON = "2018-03-10T11:39:12.110000+00:00" +RUN_STARTED_AT_STRING_JAVA = "2018-03-10T11:39:12.11Z" +RUN_ENDED_AT = datetime(2018, 3, 10, 11, 39, 15, 110000, tzinfo=UTC) +RUN_ENDED_AT_STRING = "2018-03-10T11:39:15.11Z" + +TAG = Tag(key="tag-key", value="tag-value") +TAG_BODY = {"key": "tag-key", "value": "tag-value"} + +OTHER_TAG = Tag(key="other-tag-key", value="other-tag-value") +OTHER_TAG_BODY = {"key": "other-tag-key", "value": "other-tag-value"} + +PARAM = Param(key="param-key", value="param-value") +PARAM_BODY = {"key": "param-key", "value": "param-value"} + +METRIC_KEY = "metric-key" +METRIC = Metric( + key=METRIC_KEY, + value=123.0, + timestamp=datetime(2018, 3, 12, 16, 20, 22, 122000, tzinfo=UTC), + step=0, +) +METRIC_BODY = { + "key": METRIC.key, + "value": METRIC.value, + "timestamp": "2018-03-12T16:20:22.122000+00:00", + "step": 
METRIC.step, +} + +METRIC_DATA_POINT = MetricDataPoint( + value=METRIC.value, timestamp=METRIC.timestamp, step=METRIC.step +) +METRIC_DATA_POINT_BODY = { + "value": METRIC_BODY["value"], + "timestamp": METRIC_BODY["timestamp"], + "step": METRIC_BODY["step"], +} + +METRIC_HISTORY = MetricHistory( + original_size=1, + subsampled=False, + key=METRIC_KEY, + history=[METRIC_DATA_POINT], +) +METRIC_HISTORY_BODY = { + "originalSize": METRIC_HISTORY.original_size, + "subsampled": METRIC_HISTORY.subsampled, + "key": METRIC_HISTORY.key, + "history": [METRIC_DATA_POINT_BODY], +} + +EXPERIMENT_RUN = ExperimentRun( + id=EXPERIMENT_RUN_ID, + run_number=EXPERIMENT_RUN_NUMBER, + name=EXPERIMENT_RUN_NAME, + parent_run_id=PARENT_RUN_ID, + experiment_id=EXPERIMENT.id, + artifact_location="faculty:", + status=ExperimentRunStatus.RUNNING, + started_at=RUN_STARTED_AT, + ended_at=RUN_ENDED_AT, + deleted_at=DELETED_AT, + tags=[TAG], + params=[PARAM], + metrics=[METRIC], +) +EXPERIMENT_RUN_BODY = { + "experimentId": EXPERIMENT.id, + "runId": str(EXPERIMENT_RUN_ID), + "runNumber": EXPERIMENT_RUN_NUMBER, + "name": EXPERIMENT_RUN_NAME, + "parentRunId": str(PARENT_RUN_ID), + "artifactLocation": "faculty:", + "status": "running", + "startedAt": RUN_STARTED_AT_STRING_JAVA, + "endedAt": RUN_ENDED_AT_STRING, + "deletedAt": DELETED_AT_STRING, + "tags": [TAG_BODY], + "metrics": [METRIC_BODY], + "params": [PARAM_BODY], +} + +EXPERIMENT_RUN_DATA_BODY = { + "metrics": [METRIC_BODY], + "params": [PARAM_BODY], + "tags": [TAG_BODY], +} + + +PAGE = Page(start=3, limit=10) +PAGE_BODY = {"start": PAGE.start, "limit": PAGE.limit} + +PAGINATION = Pagination( + start=20, + size=10, + previous=Page(start=10, limit=10), + next=Page(start=30, limit=10), +) +PAGINATION_BODY = { + "start": PAGINATION.start, + "size": PAGINATION.size, + "previous": { + "start": PAGINATION.previous.start, + "limit": PAGINATION.previous.limit, + }, + "next": {"start": PAGINATION.next.start, "limit": PAGINATION.next.limit}, +} + 
+LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse( + runs=[EXPERIMENT_RUN], pagination=PAGINATION +) +LIST_EXPERIMENT_RUNS_RESPONSE_BODY = { + "runs": [EXPERIMENT_RUN_BODY], + "pagination": PAGINATION_BODY, +} + +DELETE_EXPERIMENT_RUNS_RESPONSE = DeleteExperimentRunsResponse( + deleted_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] +) +DELETE_EXPERIMENT_RUNS_RESPONSE_BODY = { + "deletedRunIds": [ + str(run_id) + for run_id in DELETE_EXPERIMENT_RUNS_RESPONSE.deleted_run_ids + ], + "conflictedRunIds": [ + str(run_id) + for run_id in DELETE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids + ], +} + +RESTORE_EXPERIMENT_RUNS_RESPONSE = RestoreExperimentRunsResponse( + restored_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] +) +RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY = { + "restoredRunIds": [ + str(run_id) + for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.restored_run_ids + ], + "conflictedRunIds": [ + str(run_id) + for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids + ], +} + + +def test_experiment_schema(): + data = ExperimentSchema().load(EXPERIMENT_BODY) + assert data == EXPERIMENT + + +def test_experiment_schema_nullable_deleted_at(): + body = EXPERIMENT_BODY.copy() + body["deletedAt"] = None + data = ExperimentSchema().load(body) + assert data.deleted_at is None + + +def test_experiment_schema_invalid(): + with pytest.raises(ValidationError): + ExperimentSchema().load({}) + + +def test_experiment_run_schema(): + data = ExperimentRunSchema().load(EXPERIMENT_RUN_BODY) + assert data == EXPERIMENT_RUN + + +@pytest.mark.parametrize( + "data_key, field", + [ + ("parentRunId", "parent_run_id"), + ("endedAt", "ended_at"), + ("deletedAt", "deleted_at"), + ], +) +def test_experiment_run_schema_nullable_field(data_key, field): + body = EXPERIMENT_RUN_BODY.copy() + del body[data_key] + data = ExperimentRunSchema().load(body) + assert getattr(data, field) is None + + +@pytest.mark.parametrize("parent_run_id", [None, 
PARENT_RUN_ID]) +@pytest.mark.parametrize( + "started_at", + [RUN_STARTED_AT, RUN_STARTED_AT_NO_TIMEZONE], + ids=["timezone", "no timezone"], +) +@pytest.mark.parametrize("artifact_location", [None, "faculty:project-id"]) +@pytest.mark.parametrize("tags", [[], [{"key": "key", "value": "value"}]]) +def test_create_run_schema(parent_run_id, started_at, artifact_location, tags): + data = CreateRunSchema().dump( + { + "name": EXPERIMENT_RUN_NAME, + "parent_run_id": parent_run_id, + "started_at": started_at, + "artifact_location": artifact_location, + "tags": tags, + } + ) + assert data == { + "name": EXPERIMENT_RUN_NAME, + "parentRunId": None if parent_run_id is None else str(parent_run_id), + "startedAt": RUN_STARTED_AT_STRING_PYTHON, + "artifactLocation": artifact_location, + "tags": tags, + } + + +def test_metric_schema(): + data = MetricSchema().load(METRIC_BODY) + assert data == METRIC + + +def test_param_schema(): + data = ParamSchema().load(PARAM_BODY) + assert data == PARAM + + +def test_tag_schema(): + data = TagSchema().load(TAG_BODY) + assert data == TAG + + +def test_tag_schema_dump(): + data = TagSchema().dump(TAG_BODY) + assert data == TAG_BODY + + +def test_experiment_run_data_schema(): + data = ExperimentRunDataSchema().dump( + {"metrics": [METRIC], "params": [PARAM], "tags": [TAG]} + ) + assert data == EXPERIMENT_RUN_DATA_BODY + + +def test_experiment_run_data_schema_empty(): + data = ExperimentRunDataSchema().dump({}) + assert data == {} + + +def test_experiment_run_data_schema_multiple(): + data = ExperimentRunDataSchema().dump({"tags": [TAG, OTHER_TAG]}) + assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} + + +PROJECT_ID_FILTER = ProjectIdFilter(ComparisonOperator.EQUAL_TO, PROJECT_ID) +PROJECT_ID_FILTER_BODY = { + "by": "projectId", + "operator": "eq", + "value": str(PROJECT_ID), +} + +TAG_FILTER = TagFilter("tag-key", ComparisonOperator.EQUAL_TO, "tag-value") +TAG_FILTER_BODY = { + "by": "tag", + "key": "tag-key", + "operator": "eq", + "value": 
"tag-value", +} + + +DEFINED_TEST_CASES = [ + (ComparisonOperator.DEFINED, False, "defined", False), + (ComparisonOperator.DEFINED, True, "defined", True), + (ComparisonOperator.DEFINED, 0, "defined", False), + (ComparisonOperator.DEFINED, 1, "defined", True), +] + + +def discrete_test_cases(value, expected): + return DEFINED_TEST_CASES + [ + (ComparisonOperator.EQUAL_TO, value, "eq", expected), + (ComparisonOperator.NOT_EQUAL_TO, value, "ne", expected), + ] + + +def continuous_test_cases(value, expected): + return discrete_test_cases(value, expected) + [ + (ComparisonOperator.GREATER_THAN, value, "gt", expected), + (ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value, "ge", expected), + (ComparisonOperator.LESS_THAN, value, "lt", expected), + (ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value, "le", expected), + ] + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(PROJECT_ID, str(PROJECT_ID)), +) +def test_filter_schema_project_id( + operator, value, expected_operator, expected_value +): + filter = ProjectIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "projectId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(EXPERIMENT_ID, EXPERIMENT_ID), +) +def test_filter_schema_experiment_id( + operator, value, expected_operator, expected_value +): + filter = ExperimentIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "experimentId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(RUN_ID, str(RUN_ID)), +) +def test_filter_schema_run_id( + operator, value, expected_operator, expected_value +): + filter = RunIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + 
"by": "runId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + continuous_test_cases(DELETED_AT, DELETED_AT_STRING_PYTHON), +) +def test_filter_schema_deleted_at( + operator, value, expected_operator, expected_value +): + filter = DeletedAtFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "deletedAt", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases("tag-value", "tag-value"), +) +def test_filter_schema_tag(operator, value, expected_operator, expected_value): + filter = TagFilter("tag-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "tag", + "key": "tag-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases("param-value", "param-value") + + continuous_test_cases(123.2, 123.2), +) +def test_filter_schema_param( + operator, value, expected_operator, expected_value +): + filter = ParamFilter("param-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "param", + "key": "param-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + continuous_test_cases(45.6, 45.6), +) +def test_filter_schema_metric( + operator, value, expected_operator, expected_value +): + filter = MetricFilter("metric-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "metric", + "key": "metric-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "filter_type", + [ProjectIdFilter, ExperimentIdFilter, RunIdFilter, DeletedAtFilter], +) +def 
test_filter_schema_invalid_value_no_key(filter_type): + filter = filter_type(ComparisonOperator.EQUAL_TO, "invalid") + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", [(ParamFilter, None), (MetricFilter, "invalid")] +) +def test_filter_schema_invalid_value_with_key(filter_type, value): + filter = filter_type("key", ComparisonOperator.EQUAL_TO, value) + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", + [ + (ProjectIdFilter, PROJECT_ID), + (ExperimentIdFilter, EXPERIMENT_ID), + (RunIdFilter, RUN_ID), + ], +) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_no_key(filter_type, value, operator): + filter = filter_type(operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", + [(TagFilter, "tag-value"), (ParamFilter, "param-string-value")], +) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_with_key(filter_type, value, operator): + filter = filter_type("key", operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "operator, expected_operator", + [(LogicalOperator.AND, "and"), (LogicalOperator.OR, "or")], +) +def test_filter_schema_compound(operator, expected_operator): + filter = CompoundFilter(operator, [PROJECT_ID_FILTER, TAG_FILTER]) + data = FilterSchema().dump(filter) + assert data == { + "operator": 
expected_operator, + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + } + + +def test_filter_schema_nested(): + filter = CompoundFilter( + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.AND, [PROJECT_ID_FILTER, TAG_FILTER] + ), + CompoundFilter( + LogicalOperator.OR, [TAG_FILTER, PROJECT_ID_FILTER] + ), + ], + ) + data = FilterSchema().dump(filter) + assert data == { + "operator": "and", + "conditions": [ + { + "operator": "and", + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + }, + { + "operator": "or", + "conditions": [TAG_FILTER_BODY, PROJECT_ID_FILTER_BODY], + }, + ], + } + + +@pytest.mark.parametrize( + "sort_type, by", + [ + (StartedAtSort, "startedAt"), + (RunNumberSort, "runNumber"), + (DurationSort, "duration"), + ], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_no_tag(sort_type, by, order, expected_order): + sort = sort_type(order) + data = SortSchema().dump(sort) + assert data == {"by": by, "order": expected_order} + + +@pytest.mark.parametrize( + "sort_type, by", + [(TagSort, "tag"), (ParamSort, "param"), (MetricSort, "metric")], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_with_tag(sort_type, by, order, expected_order): + sort = sort_type("sort-key", order) + data = SortSchema().dump(sort) + assert data == {"by": by, "key": "sort-key", "order": expected_order} + + +def test_run_query_schema(mocker): + mocker.patch.object(FilterSchema, "dump") + mocker.patch.object(SortSchema, "dump") + mocker.patch.object(PageSchema, "dump") + + filter = mocker.Mock() + sorts = [mocker.Mock(), mocker.Mock()] + page = mocker.Mock() + + run_query = RunQuery(filter, sorts, page) + data = RunQuerySchema().dump(run_query) + + assert data == { + "filter": FilterSchema.dump.return_value, + "sort": [SortSchema.dump.return_value, SortSchema.dump.return_value], + "page": 
PageSchema.dump.return_value, + } + + +def test_run_query_schema_defaults(): + run_query = RunQuery(None, None, None) + data = RunQuerySchema().dump(run_query) + assert data == {"filter": None, "sort": None, "page": None} + + +def test_list_runs_schema(mocker): + data = ListExperimentRunsResponseSchema().load( + LIST_EXPERIMENT_RUNS_RESPONSE_BODY + ) + assert data == LIST_EXPERIMENT_RUNS_RESPONSE + + +def test_page_schema_load(): + data = PageSchema().load(PAGE_BODY) + assert data == PAGE + + +def test_page_schema_dump(): + data = PageSchema().dump(PAGE) + assert data == PAGE_BODY + + +def test_pagination_schema(): + data = PaginationSchema().load(PAGINATION_BODY) + assert data == PAGINATION + + +@pytest.mark.parametrize("field", ["previous", "next"]) +def test_pagination_schema_nullable_field(field): + body = PAGINATION_BODY.copy() + del body[field] + data = PaginationSchema().load(body) + assert getattr(data, field) is None + + +def test_delete_experiment_runs_response_schema(mocker): + data = DeleteExperimentRunsResponseSchema().load( + DELETE_EXPERIMENT_RUNS_RESPONSE_BODY + ) + assert data == DELETE_EXPERIMENT_RUNS_RESPONSE + + +def test_delete_experiment_runs_response_schema_invalid(mocker): + with pytest.raises(ValidationError): + DeleteExperimentRunsResponseSchema().load({}) + + +def test_restore_experiment_runs_response_schema(mocker): + data = RestoreExperimentRunsResponseSchema().load( + RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY + ) + assert data == RESTORE_EXPERIMENT_RUNS_RESPONSE + + +def test_restore_experiment_runs_response_schema_invalid(mocker): + with pytest.raises(ValidationError): + RestoreExperimentRunsResponseSchema().load({}) + + +def test_metric_history_schema(): + data = MetricHistorySchema().load(METRIC_HISTORY_BODY) + assert data == METRIC_HISTORY + + +def test_metric_history_schema_invalid(): + with pytest.raises(ValidationError): + MetricHistorySchema().load({}) From 21a64ca523d1c07bab9405b91e26f4967f2f2a86 Mon Sep 17 00:00:00 2001 From: 
Andrew Crozier Date: Wed, 12 Jun 2019 13:21:27 +0100 Subject: [PATCH 52/60] Move client to own module --- faculty/clients/experiment/__init__.py | 558 +---------------- faculty/clients/experiment/client.py | 570 ++++++++++++++++++ .../{test_init.py => test_client.py} | 48 +- 3 files changed, 603 insertions(+), 573 deletions(-) create mode 100644 faculty/clients/experiment/client.py rename tests/clients/experiment/{test_init.py => test_client.py} (92%) diff --git a/faculty/clients/experiment/__init__.py b/faculty/clients/experiment/__init__.py index 798a0154..3e322a81 100644 --- a/faculty/clients/experiment/__init__.py +++ b/faculty/clients/experiment/__init__.py @@ -12,558 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from faculty.clients.base import BaseClient, Conflict -from faculty.clients.experiment.models import ( - ComparisonOperator, - CompoundFilter, - DeleteExperimentRunsResponse, - DeletedAtFilter, - ExperimentIdFilter, - LifecycleStage, - ListExperimentRunsResponse, - LogicalOperator, - Metric, - Page, - Pagination, - RestoreExperimentRunsResponse, - RunIdFilter, - RunQuery, +from faculty.clients.experiment.client import ( # noqa: F401 + ExperimentClient, + ExperimentDeleted, + ExperimentNameConflict, + ParamConflict, ) -from faculty.clients.experiment.schemas import ( - CreateRunSchema, - DeleteExperimentRunsResponseSchema, - ExperimentRunDataSchema, - ExperimentRunInfoSchema, - ExperimentRunSchema, - ExperimentSchema, - FilterSchema, - ListExperimentRunsResponseSchema, - MetricHistorySchema, - RestoreExperimentRunsResponseSchema, - RunQuerySchema, -) - - -class ExperimentNameConflict(Exception): - def __init__(self, name): - tpl = "An experiment with name '{}' already exists in that project" - message = tpl.format(name) - super(ExperimentNameConflict, self).__init__(message) - - -class ParamConflict(Exception): - def __init__(self, message, conflicting_params=None): - 
super(ParamConflict, self).__init__(message) - if conflicting_params is None: - self.conflicting_params = [] - else: - self.conflicting_params = conflicting_params - - -class ExperimentDeleted(Exception): - def __init__(self, message, experiment_id): - super(ExperimentDeleted, self).__init__(message) - self.experiment_id = experiment_id - - -class ExperimentClient(BaseClient): - - SERVICE_NAME = "atlas" - - def create( - self, project_id, name, description=None, artifact_location=None - ): - """Create an experiment. - - Parameters - ---------- - project_id : uuid.UUID - name : str - description : str, optional - artifact_location : str, optional - - Returns - ------- - Experiment - - Raises - ------ - ExperimentNameConflict - When an experiment of the provided name already exists in the - project. - """ - endpoint = "/project/{}/experiment".format(project_id) - payload = { - "name": name, - "description": description, - "artifactLocation": artifact_location, - } - try: - return self._post(endpoint, ExperimentSchema(), json=payload) - except Conflict as err: - if err.error_code == "experiment_name_conflict": - raise ExperimentNameConflict(name) - else: - raise - - def get(self, project_id, experiment_id): - """Get a specified experiment. - - Parameters - ---------- - project_id : uuid.UUID - experiment_id : int - - Returns - ------- - Experiment - """ - endpoint = "/project/{}/experiment/{}".format( - project_id, experiment_id - ) - return self._get(endpoint, ExperimentSchema()) - - def list(self, project_id, lifecycle_stage=None): - """List the experiments in a project. - - Parameters - ---------- - project_id : uuid.UUID - lifecycle_stage : LifecycleStage, optional - To filter experiments in the given lifecycle stage only - (ACTIVE | DELETED). By default, all experiments in the - project are returned. 
- - Returns - ------- - List[Experiment] - """ - query_params = {} - if lifecycle_stage is not None: - query_params["lifecycleStage"] = lifecycle_stage.value - endpoint = "/project/{}/experiment".format(project_id) - return self._get( - endpoint, ExperimentSchema(many=True), params=query_params - ) - - def update(self, project_id, experiment_id, name=None, description=None): - """Update the name and/or description of an experiment. - - Parameters - ---------- - project_id : uuid.UUID - experiment_id : int - name : str, optional - The new name of the experiment. If not provided, the name will not - be modified. - description : str, optional - The new description of the experiment. If not provided, the - description will not be modified. - - Raises - ------ - ExperimentNameConflict - When an experiment of the provided name already exists in the - project. - """ - endpoint = "/project/{}/experiment/{}".format( - project_id, experiment_id - ) - payload = {"name": name, "description": description} - try: - self._patch_raw(endpoint, json=payload) - except Conflict as err: - if err.error_code == "experiment_name_conflict": - raise ExperimentNameConflict(name) - else: - raise - - def delete(self, project_id, experiment_id): - """Delete a specified experiment. - - Parameters - ---------- - project_id : uuid.UUID - experiment_id : int - """ - endpoint = "/project/{}/experiment/{}".format( - project_id, experiment_id - ) - self._delete_raw(endpoint) - - def restore(self, project_id, experiment_id): - """Restore a specified experiment. - - Parameters - ---------- - project_id : uuid.UUID - experiment_id : int - """ - endpoint = "/project/{}/experiment/{}/restore".format( - project_id, experiment_id - ) - self._put_raw(endpoint) - - def create_run( - self, - project_id, - experiment_id, - name, - started_at, - parent_run_id=None, - artifact_location=None, - tags=None, - ): - """Create a run in a project. 
- - Parameters - ---------- - project_id : uuid.UUID - experiment_id : int - name : str - started_at : datetime.datetime - Time at which the run was started. If the datetime does not have a - timezone, it will be assumed to be in UTC. - parent_run_id : uuid.UUID, optional - The ID of the parent run, if any. - artifact_location: str, optional - The location of the artifact repository to use for this run. - If omitted, the value of `artifact_location` for the experiment - will be used. - tags: List[Tag] - - Returns - ------- - ExperimentRun - - Raises - ------ - ExperimentDeleted - When the run that is being updated refers to an experiment that is - deleted - """ - if tags is None: - tags = [] - - endpoint = "/project/{}/experiment/{}/run".format( - project_id, experiment_id - ) - payload = CreateRunSchema().dump( - { - "name": name, - "parent_run_id": parent_run_id, - "started_at": started_at, - "artifact_location": artifact_location, - "tags": tags, - } - ) - try: - return self._post(endpoint, ExperimentRunSchema(), json=payload) - except Conflict as err: - if err.error_code == "experiment_deleted": - raise ExperimentDeleted( - err.error, err.response.json()["experimentId"] - ) - else: - raise - - def get_run(self, project_id, run_id): - """Get a specified experiment run. - - Parameters - ---------- - project_id : uuid.UUID - run_id : uuid.UUID - - Returns - ------- - ExperimentRun - """ - endpoint = "/project/{}/run/{}".format(project_id, run_id) - return self._get(endpoint, ExperimentRunSchema()) - - def list_runs( - self, - project_id, - experiment_ids=None, - lifecycle_stage=None, - start=None, - limit=None, - ): - """List experiment runs. - - This method returns pages of runs. 
If less than the full number of runs - for the job is returned, the ``next`` page of the returned response - object will not be ``None``: - - >>> response = client.list_runs(project_id) - >>> response.pagination.next - Page(start=10, limit=10) - - Get all experiment runs by making successive calls to ``list_runs``, - passing the ``start`` and ``limit`` of the ``next`` page each time - until ``next`` is returned as ``None``. - - Parameters - ---------- - project_id : uuid.UUID - experiment_ids : List[int], optional - To filter runs of experiments with the given IDs only. If an empty - list is passed, a result with an empty list of runs is returned. - By default, runs from all experiments are returned. - lifecycle_stage: LifecycleStage, optional - To filter runs of experiments in a specific lifecycle stage only. - By default, runs in any stage are returned. - start : int, optional - The (zero-indexed) starting point of runs to retrieve. - limit : int, optional - The maximum number of runs to retrieve. 
- - Returns - ------- - ListExperimentRunsResponse - """ - - experiment_ids_filter = None - lifecycle_filter = None - filter = None - - if experiment_ids is not None: - if len(experiment_ids) == 0: - return ListExperimentRunsResponse( - runs=[], - pagination=Pagination( - start=0, size=0, previous=None, next=None - ), - ) - experiment_id_filters = [ - ExperimentIdFilter(ComparisonOperator.EQUAL_TO, experiment_id) - for experiment_id in experiment_ids - ] - experiment_ids_filter = CompoundFilter( - LogicalOperator.OR, experiment_id_filters - ) - if lifecycle_stage is not None: - lifecycle_filter = DeletedAtFilter( - ComparisonOperator.DEFINED, - lifecycle_stage == LifecycleStage.DELETED, - ) - - if experiment_ids_filter is not None and lifecycle_filter is not None: - filter = CompoundFilter( - LogicalOperator.AND, [experiment_ids_filter, lifecycle_filter] - ) - elif experiment_ids_filter is not None: - filter = experiment_ids_filter - elif lifecycle_filter is not None: - filter = lifecycle_filter - - return self.query_runs(project_id, filter, None, start, limit) - - def query_runs( - self, project_id, filter=None, sort=None, start=None, limit=None - ): - """Query experiment runs. - - This method returns pages of runs. If less than the full number of runs - for the job is returned, the ``next`` page of the returned response - object will not be ``None``: - - >>> response = client.query_runs(project_id) - >>> response.pagination.next - Page(start=10, limit=10) - - Get all experiment runs by making successive calls to ``query_runs``, - passing the ``start`` and ``limit`` of the ``next`` page each time - until ``next`` is returned as ``None``. - - Parameters - ---------- - project_id : uuid.UUID - filter: SingleFilter or CompoundFilter, optional - To filter runs of experiments with the given filter. By default, - runs from all experiments are returned. - sort: List[Sort], optional - Runs are order using the conditions in sort. 
The relative - importance of each condition gradually decreases in order. - By default, experiment runs are sorted by their startedAt value. - start : int, optional - The (zero-indexed) starting point of runs to retrieve. - limit : int, optional - The maximum number of runs to retrieve. - - Returns - ------- - ListExperimentRunsResponse - """ - endpoint = "/project/{}/run/query".format(project_id) - page = None - if start is not None and limit is not None: - page = Page(start, limit) - payload = RunQuerySchema().dump(RunQuery(filter, sort, page)) - return self._post( - endpoint, ListExperimentRunsResponseSchema(), json=payload - ) - - def log_run_data( - self, project_id, run_id, metrics=None, params=None, tags=None - ): - """Update the data of a run. - - Parameters - ---------- - project_id : uuid.UUID - run_id : uuid.UUID - metrics : List[Metric], optional - Each metric will be inserted. - params : List[Param], optional - Each param will be inserted. Note that on a name conflict the - entire operation will be rejected. - tags : List[Tag], optional - Each tag be upserted. - - Raises - ------ - ParamConflict - When a provided param already exists and has a different value than - was specified. - """ - if all(kwarg is None for kwarg in [metrics, params, tags]): - return - endpoint = "/project/{}/run/{}/data".format(project_id, run_id) - payload = ExperimentRunDataSchema().dump( - {"metrics": metrics, "params": params, "tags": tags} - ) - try: - self._patch_raw(endpoint, json=payload) - except Conflict as err: - if err.error_code == "conflicting_params": - raise ParamConflict( - err.error, err.response.json()["parameterKeys"] - ) - else: - raise - - def update_run_info(self, project_id, run_id, status=None, ended_at=None): - """Update the status and end time of a run. 
- - Parameters - ---------- - project_id : uuid.UUID - run_id : uuid.UUID - status: ExperimentRunStatus, optional - ended_at: datetime, optional - - Returns - ------- - ExperimentRun - """ - endpoint = "/project/{}/run/{}/info".format(project_id, run_id) - payload = ExperimentRunInfoSchema().dump( - {"status": status, "ended_at": ended_at} - ) - return self._patch(endpoint, ExperimentRunSchema(), json=payload) - - def get_metric_history(self, project_id, run_id, key): - """Get the history of a metric. - - Parameters - ---------- - project_id : uuid.UUID - run_id : uuid.UUID - key: string - - Returns - ------- - List[Metric], ordered by timestamp and value - """ - endpoint = "/project/{}/run/{}/metric/{}/history".format( - project_id, run_id, key - ) - metric_history = self._get(endpoint, MetricHistorySchema()) - return [ - Metric( - key=metric_history.key, - value=metric_data_point.value, - timestamp=metric_data_point.timestamp, - step=metric_data_point.step, - ) - for metric_data_point in metric_history.history - ] - - def delete_runs(self, project_id, run_ids=None): - """Delete experiment runs. - - Parameters - ---------- - project_id : uuid.UUID - run_ids : List[uuid.UUID], optional - A list of run IDs to delete. If not specified, all runs in the - project will be deleted. If an empty list is passed, no runs - will be deleted. - - Returns - ------- - DeleteExperimentRunsResponse - Containing lists of successfully deleted and conflicting (already - deleted) run IDs. 
- """ - endpoint = "/project/{}/run/delete/query".format(project_id) - - if run_ids is None: - # Delete all runs in project - payload = {} # No filter - elif len(run_ids) == 0: - return DeleteExperimentRunsResponse( - deleted_run_ids=[], conflicted_run_ids=[] - ) - else: - run_id_filters = [ - RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) - for run_id in run_ids - ] - filter = CompoundFilter(LogicalOperator.OR, run_id_filters) - payload = {"filter": FilterSchema().dump(filter)} - - return self._post( - endpoint, DeleteExperimentRunsResponseSchema(), json=payload - ) - - def restore_runs(self, project_id, run_ids=None): - """Restore experiment runs. - - Parameters - ---------- - project_id : uuid.UUID - run_ids : List[uuid.UUID], optional - A list of run IDs to restore. If not specified, all runs in the - project will be restored. If an empty list is passed, no runs - will be restored. - - Returns - ------- - RestoreExperimentRunsResponse - Containing lists of successfully restored and conflicting (already - active) run IDs. 
- """ - endpoint = "/project/{}/run/restore/query".format(project_id) - - if run_ids is None: - # Restore all runs in project - payload = {} # No filter - elif len(run_ids) == 0: - return RestoreExperimentRunsResponse( - restored_run_ids=[], conflicted_run_ids=[] - ) - else: - run_id_filters = [ - RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) - for run_id in run_ids - ] - filter = CompoundFilter(LogicalOperator.OR, run_id_filters) - payload = {"filter": FilterSchema().dump(filter)} - - return self._post( - endpoint, RestoreExperimentRunsResponseSchema(), json=payload - ) diff --git a/faculty/clients/experiment/client.py b/faculty/clients/experiment/client.py new file mode 100644 index 00000000..71e595e0 --- /dev/null +++ b/faculty/clients/experiment/client.py @@ -0,0 +1,570 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from faculty.clients.base import BaseClient, Conflict + +from faculty.clients.experiment.models import ( + ComparisonOperator, + CompoundFilter, + DeleteExperimentRunsResponse, + DeletedAtFilter, + ExperimentIdFilter, + LifecycleStage, + ListExperimentRunsResponse, + LogicalOperator, + Metric, + Page, + Pagination, + RestoreExperimentRunsResponse, + RunIdFilter, + RunQuery, +) +from faculty.clients.experiment.schemas import ( + CreateRunSchema, + DeleteExperimentRunsResponseSchema, + ExperimentRunDataSchema, + ExperimentRunInfoSchema, + ExperimentRunSchema, + ExperimentSchema, + FilterSchema, + ListExperimentRunsResponseSchema, + MetricHistorySchema, + RestoreExperimentRunsResponseSchema, + RunQuerySchema, +) + + +class ExperimentNameConflict(Exception): + def __init__(self, name): + tpl = "An experiment with name '{}' already exists in that project" + message = tpl.format(name) + super(ExperimentNameConflict, self).__init__(message) + + +class ParamConflict(Exception): + def __init__(self, message, conflicting_params=None): + super(ParamConflict, self).__init__(message) + if conflicting_params is None: + self.conflicting_params = [] + else: + self.conflicting_params = conflicting_params + + +class ExperimentDeleted(Exception): + def __init__(self, message, experiment_id): + super(ExperimentDeleted, self).__init__(message) + self.experiment_id = experiment_id + + +class ExperimentClient(BaseClient): + + SERVICE_NAME = "atlas" + + def create( + self, project_id, name, description=None, artifact_location=None + ): + """Create an experiment. + + Parameters + ---------- + project_id : uuid.UUID + name : str + description : str, optional + artifact_location : str, optional + + Returns + ------- + Experiment + + Raises + ------ + ExperimentNameConflict + When an experiment of the provided name already exists in the + project. 
+ """ + endpoint = "/project/{}/experiment".format(project_id) + payload = { + "name": name, + "description": description, + "artifactLocation": artifact_location, + } + try: + return self._post(endpoint, ExperimentSchema(), json=payload) + except Conflict as err: + if err.error_code == "experiment_name_conflict": + raise ExperimentNameConflict(name) + else: + raise + + def get(self, project_id, experiment_id): + """Get a specified experiment. + + Parameters + ---------- + project_id : uuid.UUID + experiment_id : int + + Returns + ------- + Experiment + """ + endpoint = "/project/{}/experiment/{}".format( + project_id, experiment_id + ) + return self._get(endpoint, ExperimentSchema()) + + def list(self, project_id, lifecycle_stage=None): + """List the experiments in a project. + + Parameters + ---------- + project_id : uuid.UUID + lifecycle_stage : LifecycleStage, optional + To filter experiments in the given lifecycle stage only + (ACTIVE | DELETED). By default, all experiments in the + project are returned. + + Returns + ------- + List[Experiment] + """ + query_params = {} + if lifecycle_stage is not None: + query_params["lifecycleStage"] = lifecycle_stage.value + endpoint = "/project/{}/experiment".format(project_id) + return self._get( + endpoint, ExperimentSchema(many=True), params=query_params + ) + + def update(self, project_id, experiment_id, name=None, description=None): + """Update the name and/or description of an experiment. + + Parameters + ---------- + project_id : uuid.UUID + experiment_id : int + name : str, optional + The new name of the experiment. If not provided, the name will not + be modified. + description : str, optional + The new description of the experiment. If not provided, the + description will not be modified. + + Raises + ------ + ExperimentNameConflict + When an experiment of the provided name already exists in the + project. 
+ """ + endpoint = "/project/{}/experiment/{}".format( + project_id, experiment_id + ) + payload = {"name": name, "description": description} + try: + self._patch_raw(endpoint, json=payload) + except Conflict as err: + if err.error_code == "experiment_name_conflict": + raise ExperimentNameConflict(name) + else: + raise + + def delete(self, project_id, experiment_id): + """Delete a specified experiment. + + Parameters + ---------- + project_id : uuid.UUID + experiment_id : int + """ + endpoint = "/project/{}/experiment/{}".format( + project_id, experiment_id + ) + self._delete_raw(endpoint) + + def restore(self, project_id, experiment_id): + """Restore a specified experiment. + + Parameters + ---------- + project_id : uuid.UUID + experiment_id : int + """ + endpoint = "/project/{}/experiment/{}/restore".format( + project_id, experiment_id + ) + self._put_raw(endpoint) + + def create_run( + self, + project_id, + experiment_id, + name, + started_at, + parent_run_id=None, + artifact_location=None, + tags=None, + ): + """Create a run in a project. + + Parameters + ---------- + project_id : uuid.UUID + experiment_id : int + name : str + started_at : datetime.datetime + Time at which the run was started. If the datetime does not have a + timezone, it will be assumed to be in UTC. + parent_run_id : uuid.UUID, optional + The ID of the parent run, if any. + artifact_location: str, optional + The location of the artifact repository to use for this run. + If omitted, the value of `artifact_location` for the experiment + will be used. 
+ tags: List[Tag] + + Returns + ------- + ExperimentRun + + Raises + ------ + ExperimentDeleted + When the run that is being updated refers to an experiment that is + deleted + """ + if tags is None: + tags = [] + + endpoint = "/project/{}/experiment/{}/run".format( + project_id, experiment_id + ) + payload = CreateRunSchema().dump( + { + "name": name, + "parent_run_id": parent_run_id, + "started_at": started_at, + "artifact_location": artifact_location, + "tags": tags, + } + ) + try: + return self._post(endpoint, ExperimentRunSchema(), json=payload) + except Conflict as err: + if err.error_code == "experiment_deleted": + raise ExperimentDeleted( + err.error, err.response.json()["experimentId"] + ) + else: + raise + + def get_run(self, project_id, run_id): + """Get a specified experiment run. + + Parameters + ---------- + project_id : uuid.UUID + run_id : uuid.UUID + + Returns + ------- + ExperimentRun + """ + endpoint = "/project/{}/run/{}".format(project_id, run_id) + return self._get(endpoint, ExperimentRunSchema()) + + def list_runs( + self, + project_id, + experiment_ids=None, + lifecycle_stage=None, + start=None, + limit=None, + ): + """List experiment runs. + + This method returns pages of runs. If less than the full number of runs + for the job is returned, the ``next`` page of the returned response + object will not be ``None``: + + >>> response = client.list_runs(project_id) + >>> response.pagination.next + Page(start=10, limit=10) + + Get all experiment runs by making successive calls to ``list_runs``, + passing the ``start`` and ``limit`` of the ``next`` page each time + until ``next`` is returned as ``None``. + + Parameters + ---------- + project_id : uuid.UUID + experiment_ids : List[int], optional + To filter runs of experiments with the given IDs only. If an empty + list is passed, a result with an empty list of runs is returned. + By default, runs from all experiments are returned. 
+ lifecycle_stage: LifecycleStage, optional + To filter runs of experiments in a specific lifecycle stage only. + By default, runs in any stage are returned. + start : int, optional + The (zero-indexed) starting point of runs to retrieve. + limit : int, optional + The maximum number of runs to retrieve. + + Returns + ------- + ListExperimentRunsResponse + """ + + experiment_ids_filter = None + lifecycle_filter = None + filter = None + + if experiment_ids is not None: + if len(experiment_ids) == 0: + return ListExperimentRunsResponse( + runs=[], + pagination=Pagination( + start=0, size=0, previous=None, next=None + ), + ) + experiment_id_filters = [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, experiment_id) + for experiment_id in experiment_ids + ] + experiment_ids_filter = CompoundFilter( + LogicalOperator.OR, experiment_id_filters + ) + if lifecycle_stage is not None: + lifecycle_filter = DeletedAtFilter( + ComparisonOperator.DEFINED, + lifecycle_stage == LifecycleStage.DELETED, + ) + + if experiment_ids_filter is not None and lifecycle_filter is not None: + filter = CompoundFilter( + LogicalOperator.AND, [experiment_ids_filter, lifecycle_filter] + ) + elif experiment_ids_filter is not None: + filter = experiment_ids_filter + elif lifecycle_filter is not None: + filter = lifecycle_filter + + return self.query_runs(project_id, filter, None, start, limit) + + def query_runs( + self, project_id, filter=None, sort=None, start=None, limit=None + ): + """Query experiment runs. + + This method returns pages of runs. If less than the full number of runs + for the job is returned, the ``next`` page of the returned response + object will not be ``None``: + + >>> response = client.query_runs(project_id) + >>> response.pagination.next + Page(start=10, limit=10) + + Get all experiment runs by making successive calls to ``query_runs``, + passing the ``start`` and ``limit`` of the ``next`` page each time + until ``next`` is returned as ``None``. 
+ + Parameters + ---------- + project_id : uuid.UUID + filter: SingleFilter or CompoundFilter, optional + To filter runs of experiments with the given filter. By default, + runs from all experiments are returned. + sort: List[Sort], optional + Runs are order using the conditions in sort. The relative + importance of each condition gradually decreases in order. + By default, experiment runs are sorted by their startedAt value. + start : int, optional + The (zero-indexed) starting point of runs to retrieve. + limit : int, optional + The maximum number of runs to retrieve. + + Returns + ------- + ListExperimentRunsResponse + """ + endpoint = "/project/{}/run/query".format(project_id) + page = None + if start is not None and limit is not None: + page = Page(start, limit) + payload = RunQuerySchema().dump(RunQuery(filter, sort, page)) + return self._post( + endpoint, ListExperimentRunsResponseSchema(), json=payload + ) + + def log_run_data( + self, project_id, run_id, metrics=None, params=None, tags=None + ): + """Update the data of a run. + + Parameters + ---------- + project_id : uuid.UUID + run_id : uuid.UUID + metrics : List[Metric], optional + Each metric will be inserted. + params : List[Param], optional + Each param will be inserted. Note that on a name conflict the + entire operation will be rejected. + tags : List[Tag], optional + Each tag be upserted. + + Raises + ------ + ParamConflict + When a provided param already exists and has a different value than + was specified. 
+ """ + if all(kwarg is None for kwarg in [metrics, params, tags]): + return + endpoint = "/project/{}/run/{}/data".format(project_id, run_id) + payload = ExperimentRunDataSchema().dump( + {"metrics": metrics, "params": params, "tags": tags} + ) + try: + self._patch_raw(endpoint, json=payload) + except Conflict as err: + if err.error_code == "conflicting_params": + raise ParamConflict( + err.error, err.response.json()["parameterKeys"] + ) + else: + raise + + def update_run_info(self, project_id, run_id, status=None, ended_at=None): + """Update the status and end time of a run. + + Parameters + ---------- + project_id : uuid.UUID + run_id : uuid.UUID + status: ExperimentRunStatus, optional + ended_at: datetime, optional + + Returns + ------- + ExperimentRun + """ + endpoint = "/project/{}/run/{}/info".format(project_id, run_id) + payload = ExperimentRunInfoSchema().dump( + {"status": status, "ended_at": ended_at} + ) + return self._patch(endpoint, ExperimentRunSchema(), json=payload) + + def get_metric_history(self, project_id, run_id, key): + """Get the history of a metric. + + Parameters + ---------- + project_id : uuid.UUID + run_id : uuid.UUID + key: string + + Returns + ------- + List[Metric], ordered by timestamp and value + """ + endpoint = "/project/{}/run/{}/metric/{}/history".format( + project_id, run_id, key + ) + metric_history = self._get(endpoint, MetricHistorySchema()) + return [ + Metric( + key=metric_history.key, + value=metric_data_point.value, + timestamp=metric_data_point.timestamp, + step=metric_data_point.step, + ) + for metric_data_point in metric_history.history + ] + + def delete_runs(self, project_id, run_ids=None): + """Delete experiment runs. + + Parameters + ---------- + project_id : uuid.UUID + run_ids : List[uuid.UUID], optional + A list of run IDs to delete. If not specified, all runs in the + project will be deleted. If an empty list is passed, no runs + will be deleted. 
+ + Returns + ------- + DeleteExperimentRunsResponse + Containing lists of successfully deleted and conflicting (already + deleted) run IDs. + """ + endpoint = "/project/{}/run/delete/query".format(project_id) + + if run_ids is None: + # Delete all runs in project + payload = {} # No filter + elif len(run_ids) == 0: + return DeleteExperimentRunsResponse( + deleted_run_ids=[], conflicted_run_ids=[] + ) + else: + run_id_filters = [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) + for run_id in run_ids + ] + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} + + return self._post( + endpoint, DeleteExperimentRunsResponseSchema(), json=payload + ) + + def restore_runs(self, project_id, run_ids=None): + """Restore experiment runs. + + Parameters + ---------- + project_id : uuid.UUID + run_ids : List[uuid.UUID], optional + A list of run IDs to restore. If not specified, all runs in the + project will be restored. If an empty list is passed, no runs + will be restored. + + Returns + ------- + RestoreExperimentRunsResponse + Containing lists of successfully restored and conflicting (already + active) run IDs. 
+ """ + endpoint = "/project/{}/run/restore/query".format(project_id) + + if run_ids is None: + # Restore all runs in project + payload = {} # No filter + elif len(run_ids) == 0: + return RestoreExperimentRunsResponse( + restored_run_ids=[], conflicted_run_ids=[] + ) + else: + run_id_filters = [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) + for run_id in run_ids + ] + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} + + return self._post( + endpoint, RestoreExperimentRunsResponseSchema(), json=payload + ) diff --git a/tests/clients/experiment/test_init.py b/tests/clients/experiment/test_client.py similarity index 92% rename from tests/clients/experiment/test_init.py rename to tests/clients/experiment/test_client.py index ff3e8d0f..7fba8f28 100644 --- a/tests/clients/experiment/test_init.py +++ b/tests/clients/experiment/test_client.py @@ -18,7 +18,7 @@ import pytest from faculty.clients.base import Conflict -from faculty.clients.experiment import ( +from faculty.clients.experiment.client import ( ExperimentClient, ExperimentDeleted, ExperimentNameConflict, @@ -49,7 +49,9 @@ def test_experiment_client_create(mocker, description, artifact_location): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_post", return_value=experiment) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") + schema_mock = mocker.patch( + "faculty.clients.experiment.client.ExperimentSchema" + ) client = ExperimentClient(mocker.Mock()) returned_experiment = client.create( @@ -84,7 +86,9 @@ def test_experiment_client_create_name_conflict(mocker): def test_experiment_client_get(mocker): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=experiment) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") + schema_mock = mocker.patch( + "faculty.clients.experiment.client.ExperimentSchema" + ) client = ExperimentClient(mocker.Mock()) 
returned_experiment = client.get(PROJECT_ID, EXPERIMENT_ID) @@ -100,7 +104,9 @@ def test_experiment_client_get(mocker): def test_experiment_client_list(mocker): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") + schema_mock = mocker.patch( + "faculty.clients.experiment.client.ExperimentSchema" + ) client = ExperimentClient(mocker.Mock()) assert client.list(PROJECT_ID) == [experiment] @@ -116,7 +122,9 @@ def test_experiment_client_list(mocker): def test_experiment_client_list_lifecycle_filter(mocker): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") + schema_mock = mocker.patch( + "faculty.clients.experiment.client.ExperimentSchema" + ) client = ExperimentClient(mocker.Mock()) returned_experiments = client.list( @@ -186,11 +194,11 @@ def test_experiment_create_run(mocker): run = mocker.Mock() mocker.patch.object(ExperimentClient, "_post", return_value=run) request_schema_mock = mocker.patch( - "faculty.clients.experiment.CreateRunSchema" + "faculty.clients.experiment.client.CreateRunSchema" ) dump_mock = request_schema_mock.return_value.dump response_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" + "faculty.clients.experiment.client.ExperimentRunSchema" ) run_name = mocker.Mock() started_at = mocker.Mock() @@ -250,7 +258,7 @@ def test_experiment_client_get_run(mocker): run = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=run) schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" + "faculty.clients.experiment.client.ExperimentRunSchema" ) client = ExperimentClient(mocker.Mock()) @@ -311,10 +319,10 @@ def test_experiment_client_query_runs(mocker): list_response = mocker.Mock() mocker.patch.object(ExperimentClient, "_post", 
return_value=list_response) response_schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" + "faculty.clients.experiment.client.ListExperimentRunsResponseSchema" ) request_schema_mock = mocker.patch( - "faculty.clients.experiment.RunQuerySchema" + "faculty.clients.experiment.client.RunQuerySchema" ) request_dump_mock = request_schema_mock.return_value.dump @@ -342,7 +350,7 @@ def test_experiment_client_query_runs(mocker): def test_log_run_data(mocker): mocker.patch.object(ExperimentClient, "_patch_raw") run_data_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunDataSchema" + "faculty.clients.experiment.client.ExperimentRunDataSchema" ) run_data_dump_mock = run_data_schema_mock.return_value.dump @@ -412,10 +420,10 @@ def test_update_run_info(mocker): run = mocker.Mock() mocker.patch.object(ExperimentClient, "_patch", return_value=run) run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" + "faculty.clients.experiment.client.ExperimentRunSchema" ) run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" + "faculty.clients.experiment.client.ExperimentRunInfoSchema" ) run_info_dump_mock = run_info_schema_mock.return_value.dump @@ -448,7 +456,7 @@ def test_get_metric_history(mocker): mocker.patch.object(ExperimentClient, "_get", return_value=metric_history) metric_history_schema_mock = mocker.patch( - "faculty.clients.experiment.MetricHistorySchema" + "faculty.clients.experiment.client.MetricHistorySchema" ) client = ExperimentClient(mocker.Mock()) @@ -487,10 +495,10 @@ def test_delete_runs(mocker): ExperimentClient, "_post", return_value=delete_runs_response ) response_schema_mock = mocker.patch( - "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" + "faculty.clients.experiment.client.DeleteExperimentRunsResponseSchema" ) filter_schema_mock = mocker.patch( - "faculty.clients.experiment.FilterSchema" + 
"faculty.clients.experiment.client.FilterSchema" ) filter_dump_mock = filter_schema_mock.return_value.dump @@ -519,7 +527,7 @@ def test_delete_runs(mocker): def test_delete_runs_no_run_ids(mocker): mocker.patch.object(ExperimentClient, "_post") schema_mock = mocker.patch( - "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" + "faculty.clients.experiment.client.DeleteExperimentRunsResponseSchema" ) client = ExperimentClient(mocker.Mock()) @@ -549,10 +557,10 @@ def test_restore_runs(mocker): ExperimentClient, "_post", return_value=restore_runs_response ) response_schema_mock = mocker.patch( - "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" + "faculty.clients.experiment.client.RestoreExperimentRunsResponseSchema" ) filter_schema_mock = mocker.patch( - "faculty.clients.experiment.FilterSchema" + "faculty.clients.experiment.client.FilterSchema" ) filter_dump_mock = filter_schema_mock.return_value.dump @@ -581,7 +589,7 @@ def test_restore_runs(mocker): def test_restore_runs_no_run_ids(mocker): mocker.patch.object(ExperimentClient, "_post") schema_mock = mocker.patch( - "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" + "faculty.clients.experiment.client.RestoreExperimentRunsResponseSchema" ) client = ExperimentClient(mocker.Mock()) From 165ec4e2b30326000fd350bbad016536856cdb89 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 13:35:34 +0100 Subject: [PATCH 53/60] Restructure schemas module --- faculty/clients/experiment/schemas.py | 220 +++++++++++++------------- 1 file changed, 113 insertions(+), 107 deletions(-) diff --git a/faculty/clients/experiment/schemas.py b/faculty/clients/experiment/schemas.py index 7f0ab1ea..afd7de67 100644 --- a/faculty/clients/experiment/schemas.py +++ b/faculty/clients/experiment/schemas.py @@ -38,6 +38,33 @@ ) +class _OptionalField(fields.Field): + """Wrap another field, passing through Nones.""" + + def __init__(self, nested, *args, **kwargs): + self.nested = nested + 
super().__init__(*args, **kwargs) + + def _deserialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._deserialize(value, *args, **kwargs) + + def _serialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._serialize(value, *args, **kwargs) + + +class _OneOfSchemaWithoutType(OneOfSchema): + def dump(self, *args, **kwargs): + data = super(_OneOfSchemaWithoutType, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} + + class PageSchema(BaseSchema): start = fields.Integer(required=True) limit = fields.Integer(required=True) @@ -47,6 +74,17 @@ def make_page(self, data): return Page(**data) +class PaginationSchema(BaseSchema): + start = fields.Integer(required=True) + size = fields.Integer(required=True) + previous = fields.Nested(PageSchema, missing=None) + next = fields.Nested(PageSchema, missing=None) + + @post_load + def make_pagination(self, data): + return Pagination(**data) + + class MetricSchema(BaseSchema): key = fields.String(required=True) value = fields.Float(required=True) @@ -114,6 +152,9 @@ def make_experiment_run(self, data): return ExperimentRun(**data) +# Schemas for payloads sent to API: + + class ExperimentRunDataSchema(BaseSchema): metrics = fields.List(fields.Nested(MetricSchema)) params = fields.List(fields.Nested(ParamSchema)) @@ -125,17 +166,6 @@ class ExperimentRunInfoSchema(BaseSchema): ended_at = fields.DateTime(data_key="endedAt", missing=None) -class PaginationSchema(BaseSchema): - start = fields.Integer(required=True) - size = fields.Integer(required=True) - previous = fields.Nested(PageSchema, missing=None) - next = fields.Nested(PageSchema, missing=None) - - @post_load - def make_pagination(self, data): - return Pagination(**data) - - class ListExperimentRunsResponseSchema(BaseSchema): pagination = fields.Nested(PaginationSchema, required=True) runs = 
fields.Nested(ExperimentRunSchema, many=True, required=True) @@ -153,33 +183,7 @@ class CreateRunSchema(BaseSchema): tags = fields.Nested(TagSchema, many=True, required=True) -class DeleteExperimentRunsResponseSchema(BaseSchema): - deleted_run_ids = fields.List( - fields.UUID(), data_key="deletedRunIds", required=True - ) - conflicted_run_ids = fields.List( - fields.UUID(), data_key="conflictedRunIds", required=True - ) - - @post_load - def make_delete_runs_response(self, data): - return DeleteExperimentRunsResponse(**data) - - -class RestoreExperimentRunsResponseSchema(BaseSchema): - restored_run_ids = fields.List( - fields.UUID(), data_key="restoredRunIds", required=True - ) - conflicted_run_ids = fields.List( - fields.UUID(), data_key="conflictedRunIds", required=True - ) - - @post_load - def make_restore_runs_response(self, data): - return RestoreExperimentRunsResponse(**data) - - -class ParamFilterValueField(fields.Field): +class _ParamFilterValueField(fields.Field): """Field that passes through strings or numbers.""" default_error_messages = { @@ -196,30 +200,10 @@ def _serialize(self, value, attr, obj, **kwargs): return field._serialize(value, attr, obj, **kwargs) -class OptionalField(fields.Field): - """Wrap another field, passing through Nones.""" - - def __init__(self, nested, *args, **kwargs): - self.nested = nested - super().__init__(*args, **kwargs) - - def _deserialize(self, value, *args, **kwargs): - if value is None: - return None - else: - return self.nested._deserialize(value, *args, **kwargs) - - def _serialize(self, value, *args, **kwargs): - if value is None: - return None - else: - return self.nested._serialize(value, *args, **kwargs) - - -class FilterValueField(fields.Field): +class _FilterValueField(fields.Field): def __init__(self, other_field_type, *args, **kwargs): self.other_field_type = other_field_type - super(FilterValueField, self).__init__(*args, **kwargs) + super(_FilterValueField, self).__init__(*args, **kwargs) def 
_serialize(self, value, attr, obj, **kwargs): if obj.operator == ComparisonOperator.DEFINED: @@ -238,9 +222,9 @@ def _validate_discrete(operator): raise ValidationError({"operator": "Not a discrete operator."}) -class ProjectIdFilterSchema(BaseSchema): +class _ProjectIdFilterSchema(BaseSchema): operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.UUID) + value = _FilterValueField(fields.UUID) by = fields.Constant("projectId", dump_only=True) @pre_dump @@ -249,9 +233,9 @@ def check_operator(self, obj): return obj -class ExperimentIdFilterSchema(BaseSchema): +class _ExperimentIdFilterSchema(BaseSchema): operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.Integer) + value = _FilterValueField(fields.Integer) by = fields.Constant("experimentId", dump_only=True) @pre_dump @@ -260,9 +244,9 @@ def check_operator(self, obj): return obj -class RunIdFilterSchema(BaseSchema): +class _RunIdFilterSchema(BaseSchema): operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.UUID) + value = _FilterValueField(fields.UUID) by = fields.Constant("runId", dump_only=True) @pre_dump @@ -271,16 +255,16 @@ def check_operator(self, obj): return obj -class DeletedAtFilterSchema(BaseSchema): +class _DeletedAtFilterSchema(BaseSchema): operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.DateTime) + value = _FilterValueField(fields.DateTime) by = fields.Constant("deletedAt", dump_only=True) -class TagFilterSchema(BaseSchema): +class _TagFilterSchema(BaseSchema): key = fields.String() operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.String) + value = _FilterValueField(fields.String) by = fields.Constant("tag", dump_only=True) @pre_dump @@ -289,10 +273,10 @@ def check_operator(self, obj): return obj -class ParamFilterSchema(BaseSchema): +class _ParamFilterSchema(BaseSchema): key = fields.String() operator = 
EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(ParamFilterValueField) + value = _FilterValueField(_ParamFilterValueField) by = fields.Constant("param", dump_only=True) @pre_dump @@ -302,89 +286,111 @@ def check_operator(self, obj): return obj -class MetricFilterSchema(BaseSchema): +class _MetricFilterSchema(BaseSchema): key = fields.String() operator = EnumField(ComparisonOperator, by_value=True) - value = FilterValueField(fields.Float) + value = _FilterValueField(fields.Float) by = fields.Constant("metric", dump_only=True) -class CompoundFilterSchema(BaseSchema): +class _CompoundFilterSchema(BaseSchema): operator = EnumField(LogicalOperator, by_value=True) conditions = fields.List(fields.Nested("FilterSchema")) -class OneOfSchemaWithoutType(OneOfSchema): - def dump(self, *args, **kwargs): - data = super(OneOfSchemaWithoutType, self).dump(*args, **kwargs) - # Remove the type field added by marshmallow-oneofschema - return {k: v for k, v in data.items() if k != "type"} - - -class FilterSchema(OneOfSchemaWithoutType): +class FilterSchema(_OneOfSchemaWithoutType): type_schemas = { - "ProjectIdFilter": ProjectIdFilterSchema, - "ExperimentIdFilter": ExperimentIdFilterSchema, - "RunIdFilter": RunIdFilterSchema, - "DeletedAtFilter": DeletedAtFilterSchema, - "TagFilter": TagFilterSchema, - "ParamFilter": ParamFilterSchema, - "MetricFilter": MetricFilterSchema, - "CompoundFilter": CompoundFilterSchema, + "ProjectIdFilter": _ProjectIdFilterSchema, + "ExperimentIdFilter": _ExperimentIdFilterSchema, + "RunIdFilter": _RunIdFilterSchema, + "DeletedAtFilter": _DeletedAtFilterSchema, + "TagFilter": _TagFilterSchema, + "ParamFilter": _ParamFilterSchema, + "MetricFilter": _MetricFilterSchema, + "CompoundFilter": _CompoundFilterSchema, } -class StartedAtSortSchema(BaseSchema): +class _StartedAtSortSchema(BaseSchema): order = EnumField(SortOrder, by_value=True) by = fields.Constant("startedAt", dump_only=True) -class RunNumberSortSchema(BaseSchema): +class 
_RunNumberSortSchema(BaseSchema): order = EnumField(SortOrder, by_value=True) by = fields.Constant("runNumber", dump_only=True) -class DurationSortSchema(BaseSchema): +class _DurationSortSchema(BaseSchema): order = EnumField(SortOrder, by_value=True) by = fields.Constant("duration", dump_only=True) -class TagSortSchema(BaseSchema): +class _TagSortSchema(BaseSchema): key = fields.String() order = EnumField(SortOrder, by_value=True) by = fields.Constant("tag", dump_only=True) -class ParamSortSchema(BaseSchema): +class _ParamSortSchema(BaseSchema): key = fields.String() order = EnumField(SortOrder, by_value=True) by = fields.Constant("param", dump_only=True) -class MetricSortSchema(BaseSchema): +class _MetricSortSchema(BaseSchema): key = fields.String() order = EnumField(SortOrder, by_value=True) by = fields.Constant("metric", dump_only=True) -class SortSchema(OneOfSchemaWithoutType): +class SortSchema(_OneOfSchemaWithoutType): type_schemas = { - "StartedAtSort": StartedAtSortSchema, - "RunNumberSort": RunNumberSortSchema, - "DurationSort": DurationSortSchema, - "TagSort": TagSortSchema, - "ParamSort": ParamSortSchema, - "MetricSort": MetricSortSchema, + "StartedAtSort": _StartedAtSortSchema, + "RunNumberSort": _RunNumberSortSchema, + "DurationSort": _DurationSortSchema, + "TagSort": _TagSortSchema, + "ParamSort": _ParamSortSchema, + "MetricSort": _MetricSortSchema, } class RunQuerySchema(BaseSchema): - filter = OptionalField(fields.Nested(FilterSchema)) + filter = _OptionalField(fields.Nested(FilterSchema)) sort = fields.List(fields.Nested(SortSchema)) page = fields.Nested(PageSchema, missing=None) -class MetricDataPointSchema(BaseSchema): +# Schemas for responses returned from API: + + +class DeleteExperimentRunsResponseSchema(BaseSchema): + deleted_run_ids = fields.List( + fields.UUID(), data_key="deletedRunIds", required=True + ) + conflicted_run_ids = fields.List( + fields.UUID(), data_key="conflictedRunIds", required=True + ) + + @post_load + def 
make_delete_runs_response(self, data): + return DeleteExperimentRunsResponse(**data) + + +class RestoreExperimentRunsResponseSchema(BaseSchema): + restored_run_ids = fields.List( + fields.UUID(), data_key="restoredRunIds", required=True + ) + conflicted_run_ids = fields.List( + fields.UUID(), data_key="conflictedRunIds", required=True + ) + + @post_load + def make_restore_runs_response(self, data): + return RestoreExperimentRunsResponse(**data) + + +class _MetricDataPointSchema(BaseSchema): """Deserialise a data point from the metric history endpoint. This schema is written with the expectation that it is not used alongside @@ -405,7 +411,7 @@ class MetricHistorySchema(BaseSchema): original_size = fields.Integer(data_key="originalSize", required=True) subsampled = fields.Boolean(required=True) key = fields.String(required=True) - history = fields.Nested(MetricDataPointSchema, many=True, required=True) + history = fields.Nested(_MetricDataPointSchema, many=True, required=True) @post_load def make_history(self, data): From 6f429a96951311ecc6938ec867a098e5d66a4d5a Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 13:45:17 +0100 Subject: [PATCH 54/60] Sort models and expose at experiment package level --- faculty/clients/experiment/__init__.py | 25 ++++++++++ faculty/clients/experiment/models.py | 69 +++++++++++++------------- 2 files changed, 60 insertions(+), 34 deletions(-) diff --git a/faculty/clients/experiment/__init__.py b/faculty/clients/experiment/__init__.py index 3e322a81..d231447d 100644 --- a/faculty/clients/experiment/__init__.py +++ b/faculty/clients/experiment/__init__.py @@ -13,6 +13,31 @@ # limitations under the License. 
+from faculty.clients.experiment.models import ( # noqa: F401 + ComparisonOperator, + CompoundFilter, + DeletedAtFilter, + DurationSort, + Experiment, + ExperimentIdFilter, + ExperimentRun, + ExperimentRunStatus, + LifecycleStage, + LogicalOperator, + Metric, + MetricFilter, + MetricSort, + Param, + ParamFilter, + ParamSort, + ProjectIdFilter, + RunIdFilter, + RunNumberSort, + StartedAtSort, + Tag, + TagFilter, + TagSort, +) from faculty.clients.experiment.client import ( # noqa: F401 ExperimentClient, ExperimentDeleted, diff --git a/faculty/clients/experiment/models.py b/faculty/clients/experiment/models.py index 00772458..00b6fecc 100644 --- a/faculty/clients/experiment/models.py +++ b/faculty/clients/experiment/models.py @@ -17,6 +17,11 @@ from enum import Enum +class LifecycleStage(Enum): + ACTIVE = "active" + DELETED = "deleted" + + class ExperimentRunStatus(Enum): RUNNING = "running" FINISHED = "finished" @@ -25,6 +30,13 @@ class ExperimentRunStatus(Enum): KILLED = "killed" +Page = namedtuple("Page", ["start", "limit"]) +Pagination = namedtuple("Pagination", ["start", "size", "previous", "next"]) + +Metric = namedtuple("Metric", ["key", "value", "timestamp", "step"]) +Param = namedtuple("Param", ["key", "value"]) +Tag = namedtuple("Tag", ["key", "value"]) + Experiment = namedtuple( "Experiment", [ @@ -59,28 +71,6 @@ class ExperimentRunStatus(Enum): ) -class LifecycleStage(Enum): - ACTIVE = "active" - DELETED = "deleted" - - -Metric = namedtuple("Metric", ["key", "value", "timestamp", "step"]) -Param = namedtuple("Param", ["key", "value"]) -Tag = namedtuple("Tag", ["key", "value"]) - -Page = namedtuple("Page", ["start", "limit"]) -Pagination = namedtuple("Pagination", ["start", "size", "previous", "next"]) -ListExperimentRunsResponse = namedtuple( - "ListExperimentRunsResponse", ["runs", "pagination"] -) -DeleteExperimentRunsResponse = namedtuple( - "DeleteExperimentRunsResponse", ["deleted_run_ids", "conflicted_run_ids"] -) -RestoreExperimentRunsResponse = 
namedtuple( - "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] -) - - class ComparisonOperator(Enum): DEFINED = "defined" EQUAL_TO = "eq" @@ -91,11 +81,6 @@ class ComparisonOperator(Enum): GREATER_THAN_OR_EQUAL_TO = "ge" -class LogicalOperator(Enum): - AND = "and" - OR = "or" - - ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) @@ -104,8 +89,20 @@ class LogicalOperator(Enum): ParamFilter = namedtuple("ParamFilter", ["key", "operator", "value"]) MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) + +class LogicalOperator(Enum): + AND = "and" + OR = "or" + + CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) + +class SortOrder(Enum): + ASC = "asc" + DESC = "desc" + + StartedAtSort = namedtuple("StartedAtSort", ["order"]) RunNumberSort = namedtuple("RunNumberSort", ["order"]) DurationSort = namedtuple("DurationSort", ["order"]) @@ -113,15 +110,19 @@ class LogicalOperator(Enum): ParamSort = namedtuple("ParamSort", ["key", "order"]) MetricSort = namedtuple("MetricSort", ["key", "order"]) - -class SortOrder(Enum): - ASC = "asc" - DESC = "desc" - - RunQuery = namedtuple("RunQuery", ["filter", "sort", "page"]) -MetricDataPoint = namedtuple("Metric", ["value", "timestamp", "step"]) +MetricDataPoint = namedtuple("Metric", ["value", "timestamp", "step"]) MetricHistory = namedtuple( "MetricHistory", ["original_size", "subsampled", "key", "history"] ) + +ListExperimentRunsResponse = namedtuple( + "ListExperimentRunsResponse", ["runs", "pagination"] +) +DeleteExperimentRunsResponse = namedtuple( + "DeleteExperimentRunsResponse", ["deleted_run_ids", "conflicted_run_ids"] +) +RestoreExperimentRunsResponse = namedtuple( + "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] +) From 440e05e7dbbc05f1480f9a7e15c8e85b37e6bf54 Mon Sep 
17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 13:49:45 +0100 Subject: [PATCH 55/60] Make experiment submodules package-private --- faculty/clients/experiment/__init__.py | 4 +- .../experiment/{client.py => _client.py} | 4 +- .../experiment/{models.py => _models.py} | 0 .../experiment/{schemas.py => _schemas.py} | 2 +- tests/clients/experiment/test_client.py | 42 +++++++++---------- tests/clients/experiment/test_schemas.py | 4 +- 6 files changed, 28 insertions(+), 28 deletions(-) rename faculty/clients/experiment/{client.py => _client.py} (99%) rename faculty/clients/experiment/{models.py => _models.py} (100%) rename faculty/clients/experiment/{schemas.py => _schemas.py} (99%) diff --git a/faculty/clients/experiment/__init__.py b/faculty/clients/experiment/__init__.py index d231447d..62646f38 100644 --- a/faculty/clients/experiment/__init__.py +++ b/faculty/clients/experiment/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. -from faculty.clients.experiment.models import ( # noqa: F401 +from faculty.clients.experiment._models import ( # noqa: F401 ComparisonOperator, CompoundFilter, DeletedAtFilter, @@ -38,7 +38,7 @@ TagFilter, TagSort, ) -from faculty.clients.experiment.client import ( # noqa: F401 +from faculty.clients.experiment._client import ( # noqa: F401 ExperimentClient, ExperimentDeleted, ExperimentNameConflict, diff --git a/faculty/clients/experiment/client.py b/faculty/clients/experiment/_client.py similarity index 99% rename from faculty/clients/experiment/client.py rename to faculty/clients/experiment/_client.py index 71e595e0..8689c41a 100644 --- a/faculty/clients/experiment/client.py +++ b/faculty/clients/experiment/_client.py @@ -15,7 +15,7 @@ from faculty.clients.base import BaseClient, Conflict -from faculty.clients.experiment.models import ( +from faculty.clients.experiment._models import ( ComparisonOperator, CompoundFilter, DeleteExperimentRunsResponse, @@ -31,7 +31,7 @@ RunIdFilter, RunQuery, ) -from 
faculty.clients.experiment.schemas import ( +from faculty.clients.experiment._schemas import ( CreateRunSchema, DeleteExperimentRunsResponseSchema, ExperimentRunDataSchema, diff --git a/faculty/clients/experiment/models.py b/faculty/clients/experiment/_models.py similarity index 100% rename from faculty/clients/experiment/models.py rename to faculty/clients/experiment/_models.py diff --git a/faculty/clients/experiment/schemas.py b/faculty/clients/experiment/_schemas.py similarity index 99% rename from faculty/clients/experiment/schemas.py rename to faculty/clients/experiment/_schemas.py index afd7de67..a0a9522b 100644 --- a/faculty/clients/experiment/schemas.py +++ b/faculty/clients/experiment/_schemas.py @@ -18,7 +18,7 @@ from marshmallow_oneofschema import OneOfSchema from faculty.clients.base import BaseSchema -from faculty.clients.experiment.models import ( +from faculty.clients.experiment._models import ( ComparisonOperator, DeleteExperimentRunsResponse, Experiment, diff --git a/tests/clients/experiment/test_client.py b/tests/clients/experiment/test_client.py index 7fba8f28..d54e55f2 100644 --- a/tests/clients/experiment/test_client.py +++ b/tests/clients/experiment/test_client.py @@ -18,13 +18,13 @@ import pytest from faculty.clients.base import Conflict -from faculty.clients.experiment.client import ( +from faculty.clients.experiment._client import ( ExperimentClient, ExperimentDeleted, ExperimentNameConflict, ParamConflict, ) -from faculty.clients.experiment.models import ( +from faculty.clients.experiment._models import ( ComparisonOperator, CompoundFilter, DeletedAtFilter, @@ -50,7 +50,7 @@ def test_experiment_client_create(mocker, description, artifact_location): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_post", return_value=experiment) schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentSchema" + "faculty.clients.experiment._client.ExperimentSchema" ) client = ExperimentClient(mocker.Mock()) @@ -87,7 +87,7 
@@ def test_experiment_client_get(mocker): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=experiment) schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentSchema" + "faculty.clients.experiment._client.ExperimentSchema" ) client = ExperimentClient(mocker.Mock()) @@ -105,7 +105,7 @@ def test_experiment_client_list(mocker): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentSchema" + "faculty.clients.experiment._client.ExperimentSchema" ) client = ExperimentClient(mocker.Mock()) @@ -123,7 +123,7 @@ def test_experiment_client_list_lifecycle_filter(mocker): experiment = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentSchema" + "faculty.clients.experiment._client.ExperimentSchema" ) client = ExperimentClient(mocker.Mock()) @@ -194,11 +194,11 @@ def test_experiment_create_run(mocker): run = mocker.Mock() mocker.patch.object(ExperimentClient, "_post", return_value=run) request_schema_mock = mocker.patch( - "faculty.clients.experiment.client.CreateRunSchema" + "faculty.clients.experiment._client.CreateRunSchema" ) dump_mock = request_schema_mock.return_value.dump response_schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentRunSchema" + "faculty.clients.experiment._client.ExperimentRunSchema" ) run_name = mocker.Mock() started_at = mocker.Mock() @@ -258,7 +258,7 @@ def test_experiment_client_get_run(mocker): run = mocker.Mock() mocker.patch.object(ExperimentClient, "_get", return_value=run) schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentRunSchema" + "faculty.clients.experiment._client.ExperimentRunSchema" ) client = ExperimentClient(mocker.Mock()) @@ -319,10 +319,10 @@ def test_experiment_client_query_runs(mocker): list_response = 
mocker.Mock() mocker.patch.object(ExperimentClient, "_post", return_value=list_response) response_schema_mock = mocker.patch( - "faculty.clients.experiment.client.ListExperimentRunsResponseSchema" + "faculty.clients.experiment._client.ListExperimentRunsResponseSchema" ) request_schema_mock = mocker.patch( - "faculty.clients.experiment.client.RunQuerySchema" + "faculty.clients.experiment._client.RunQuerySchema" ) request_dump_mock = request_schema_mock.return_value.dump @@ -350,7 +350,7 @@ def test_experiment_client_query_runs(mocker): def test_log_run_data(mocker): mocker.patch.object(ExperimentClient, "_patch_raw") run_data_schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentRunDataSchema" + "faculty.clients.experiment._client.ExperimentRunDataSchema" ) run_data_dump_mock = run_data_schema_mock.return_value.dump @@ -420,10 +420,10 @@ def test_update_run_info(mocker): run = mocker.Mock() mocker.patch.object(ExperimentClient, "_patch", return_value=run) run_schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentRunSchema" + "faculty.clients.experiment._client.ExperimentRunSchema" ) run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.client.ExperimentRunInfoSchema" + "faculty.clients.experiment._client.ExperimentRunInfoSchema" ) run_info_dump_mock = run_info_schema_mock.return_value.dump @@ -456,7 +456,7 @@ def test_get_metric_history(mocker): mocker.patch.object(ExperimentClient, "_get", return_value=metric_history) metric_history_schema_mock = mocker.patch( - "faculty.clients.experiment.client.MetricHistorySchema" + "faculty.clients.experiment._client.MetricHistorySchema" ) client = ExperimentClient(mocker.Mock()) @@ -495,10 +495,10 @@ def test_delete_runs(mocker): ExperimentClient, "_post", return_value=delete_runs_response ) response_schema_mock = mocker.patch( - "faculty.clients.experiment.client.DeleteExperimentRunsResponseSchema" + "faculty.clients.experiment._client.DeleteExperimentRunsResponseSchema" 
) filter_schema_mock = mocker.patch( - "faculty.clients.experiment.client.FilterSchema" + "faculty.clients.experiment._client.FilterSchema" ) filter_dump_mock = filter_schema_mock.return_value.dump @@ -527,7 +527,7 @@ def test_delete_runs(mocker): def test_delete_runs_no_run_ids(mocker): mocker.patch.object(ExperimentClient, "_post") schema_mock = mocker.patch( - "faculty.clients.experiment.client.DeleteExperimentRunsResponseSchema" + "faculty.clients.experiment._client.DeleteExperimentRunsResponseSchema" ) client = ExperimentClient(mocker.Mock()) @@ -557,10 +557,10 @@ def test_restore_runs(mocker): ExperimentClient, "_post", return_value=restore_runs_response ) response_schema_mock = mocker.patch( - "faculty.clients.experiment.client.RestoreExperimentRunsResponseSchema" + "faculty.clients.experiment._client.RestoreExperimentRunsResponseSchema" ) filter_schema_mock = mocker.patch( - "faculty.clients.experiment.client.FilterSchema" + "faculty.clients.experiment._client.FilterSchema" ) filter_dump_mock = filter_schema_mock.return_value.dump @@ -589,7 +589,7 @@ def test_restore_runs(mocker): def test_restore_runs_no_run_ids(mocker): mocker.patch.object(ExperimentClient, "_post") schema_mock = mocker.patch( - "faculty.clients.experiment.client.RestoreExperimentRunsResponseSchema" + "faculty.clients.experiment._client.RestoreExperimentRunsResponseSchema" ) client = ExperimentClient(mocker.Mock()) diff --git a/tests/clients/experiment/test_schemas.py b/tests/clients/experiment/test_schemas.py index e5e3066b..a750ff97 100644 --- a/tests/clients/experiment/test_schemas.py +++ b/tests/clients/experiment/test_schemas.py @@ -20,7 +20,7 @@ from marshmallow import ValidationError from pytz import UTC -from faculty.clients.experiment.models import ( +from faculty.clients.experiment._models import ( ComparisonOperator, CompoundFilter, DeleteExperimentRunsResponse, @@ -53,7 +53,7 @@ TagFilter, TagSort, ) -from faculty.clients.experiment.schemas import ( +from 
faculty.clients.experiment._schemas import ( CreateRunSchema, DeleteExperimentRunsResponseSchema, ExperimentRunDataSchema, From 305efd2e7f2e0a09188f3ed35f9775323e900651 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 14:10:54 +0100 Subject: [PATCH 56/60] Fix test names --- tests/clients/test_experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index d6db90b0..3a3eb8ff 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -624,7 +624,7 @@ def test_filter_schema_nested(): @pytest.mark.parametrize( "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] ) -def test_sort_schema_no_tag(sort_type, by, order, expected_order): +def test_sort_schema_no_key(sort_type, by, order, expected_order): sort = sort_type(order) data = SortSchema().dump(sort) assert data == {"by": by, "order": expected_order} @@ -637,7 +637,7 @@ def test_sort_schema_no_tag(sort_type, by, order, expected_order): @pytest.mark.parametrize( "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] ) -def test_sort_schema_with_tag(sort_type, by, order, expected_order): +def test_sort_schema_with_key(sort_type, by, order, expected_order): sort = sort_type("sort-key", order) data = SortSchema().dump(sort) assert data == {"by": by, "key": "sort-key", "order": expected_order} From 8674556612c4650b02f47213b1352e06908198ac Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 12 Jun 2019 14:23:10 +0100 Subject: [PATCH 57/60] Reformat code according to flake8 --- tests/clients/experiment/test_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/clients/experiment/test_client.py b/tests/clients/experiment/test_client.py index d54e55f2..4476d9fa 100644 --- a/tests/clients/experiment/test_client.py +++ b/tests/clients/experiment/test_client.py @@ -557,7 +557,8 @@ def test_restore_runs(mocker): 
ExperimentClient, "_post", return_value=restore_runs_response ) response_schema_mock = mocker.patch( - "faculty.clients.experiment._client.RestoreExperimentRunsResponseSchema" + "faculty.clients.experiment._client." + "RestoreExperimentRunsResponseSchema" ) filter_schema_mock = mocker.patch( "faculty.clients.experiment._client.FilterSchema" @@ -589,7 +590,8 @@ def test_restore_runs(mocker): def test_restore_runs_no_run_ids(mocker): mocker.patch.object(ExperimentClient, "_post") schema_mock = mocker.patch( - "faculty.clients.experiment._client.RestoreExperimentRunsResponseSchema" + "faculty.clients.experiment._client." + "RestoreExperimentRunsResponseSchema" ) client = ExperimentClient(mocker.Mock()) From 9654c5c1d4d06c024443fee5be4fd14d23a922c4 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 16:29:14 +0100 Subject: [PATCH 58/60] Use beta version on Python 2.7 / 3.4 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1e427358..1e04d352 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ # compatible version of python-dateutil is available "marshmallow[reco]>=3.0.0rc3", "marshmallow_enum", - "marshmallow-oneofschema", + "marshmallow-oneofschema>=2.0.0b2", "boto3", "botocore", ], From 9b45e8c26cd426d6aff884125f82f11f08d228b6 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 16:29:38 +0100 Subject: [PATCH 59/60] Fix Python 2 compatability --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index cd243803..460c8a4e 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -318,7 +318,7 @@ class OptionalField(fields.Field): def __init__(self, nested, *args, **kwargs): self.nested = nested - super().__init__(*args, **kwargs) + super(OptionalField, self).__init__(*args, **kwargs) def _deserialize(self, value, *args, **kwargs): if value is None: 
From e05936a6e5bddb81fc878aece46deafad114ba31 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Thu, 13 Jun 2019 10:12:38 +0100 Subject: [PATCH 60/60] Fix Python 2.7 issue --- faculty/clients/experiment/_schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment/_schemas.py b/faculty/clients/experiment/_schemas.py index a0a9522b..173c2765 100644 --- a/faculty/clients/experiment/_schemas.py +++ b/faculty/clients/experiment/_schemas.py @@ -43,7 +43,7 @@ class _OptionalField(fields.Field): def __init__(self, nested, *args, **kwargs): self.nested = nested - super().__init__(*args, **kwargs) + super(_OptionalField, self).__init__(*args, **kwargs) def _deserialize(self, value, *args, **kwargs): if value is None: