diff --git a/faculty/clients/experiment/__init__.py b/faculty/clients/experiment/__init__.py new file mode 100644 index 00000000..62646f38 --- /dev/null +++ b/faculty/clients/experiment/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from faculty.clients.experiment._models import ( # noqa: F401 + ComparisonOperator, + CompoundFilter, + DeletedAtFilter, + DurationSort, + Experiment, + ExperimentIdFilter, + ExperimentRun, + ExperimentRunStatus, + LifecycleStage, + LogicalOperator, + Metric, + MetricFilter, + MetricSort, + Param, + ParamFilter, + ParamSort, + ProjectIdFilter, + RunIdFilter, + RunNumberSort, + SortOrder, + StartedAtSort, + Tag, + TagFilter, + TagSort, +) +from faculty.clients.experiment._client import ( # noqa: F401 + ExperimentClient, + ExperimentDeleted, + ExperimentNameConflict, + ParamConflict, +) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment/_client.py similarity index 62% rename from faculty/clients/experiment.py rename to faculty/clients/experiment/_client.py index 5e20866f..8689c41a 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment/_client.py @@ -12,13 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import namedtuple -from enum import Enum -from marshmallow import fields, post_load -from marshmallow_enum import EnumField - -from faculty.clients.base import BaseClient, BaseSchema, Conflict +from faculty.clients.base import BaseClient, Conflict + +from faculty.clients.experiment._models import ( + ComparisonOperator, + CompoundFilter, + DeleteExperimentRunsResponse, + DeletedAtFilter, + ExperimentIdFilter, + LifecycleStage, + ListExperimentRunsResponse, + LogicalOperator, + Metric, + Page, + Pagination, + RestoreExperimentRunsResponse, + RunIdFilter, + RunQuery, +) +from faculty.clients.experiment._schemas import ( + CreateRunSchema, + DeleteExperimentRunsResponseSchema, + ExperimentRunDataSchema, + ExperimentRunInfoSchema, + ExperimentRunSchema, + ExperimentSchema, + FilterSchema, + ListExperimentRunsResponseSchema, + MetricHistorySchema, + RestoreExperimentRunsResponseSchema, + RunQuerySchema, +) class ExperimentNameConflict(Exception): @@ -43,244 +68,6 @@ def __init__(self, message, experiment_id): self.experiment_id = experiment_id -class ExperimentRunStatus(Enum): - RUNNING = "running" - FINISHED = "finished" - FAILED = "failed" - SCHEDULED = "scheduled" - KILLED = "killed" - - -Experiment = namedtuple( - "Experiment", - [ - "id", - "name", - "description", - "artifact_location", - "created_at", - "last_updated_at", - "deleted_at", - ], -) - - -ExperimentRun = namedtuple( - "ExperimentRun", - [ - "id", - "run_number", - "experiment_id", - "name", - "parent_run_id", - "artifact_location", - "status", - "started_at", - "ended_at", - "deleted_at", - "tags", - "params", - "metrics", - ], -) - -Metric = namedtuple("Metric", ["key", "value", "timestamp", "step"]) -Param = namedtuple("Param", ["key", "value"]) -Tag = namedtuple("Tag", ["key", "value"]) - -Page = namedtuple("Page", ["start", "limit"]) -Pagination = namedtuple("Pagination", ["start", "size", "previous", "next"]) -ListExperimentRunsResponse = namedtuple( - 
"ListExperimentRunsResponse", ["runs", "pagination"] -) -DeleteExperimentRunsResponse = namedtuple( - "DeleteExperimentRunsResponse", ["deleted_run_ids", "conflicted_run_ids"] -) -RestoreExperimentRunsResponse = namedtuple( - "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] -) - -MetricDataPoint = namedtuple("Metric", ["value", "timestamp", "step"]) - -MetricHistory = namedtuple( - "MetricHistory", ["original_size", "subsampled", "key", "history"] -) - - -class MetricSchema(BaseSchema): - key = fields.String(required=True) - value = fields.Float(required=True) - timestamp = fields.DateTime(required=True) - step = fields.Integer(required=True) - - @post_load - def make_metric(self, data): - return Metric(**data) - - -class ParamSchema(BaseSchema): - key = fields.String(required=True) - value = fields.String(required=True) - - @post_load - def make_param(self, data): - return Param(**data) - - -class TagSchema(BaseSchema): - key = fields.String(required=True) - value = fields.String(required=True) - - @post_load - def make_tag(self, data): - return Tag(**data) - - -class LifecycleStage(Enum): - ACTIVE = "active" - DELETED = "deleted" - - -class ExperimentSchema(BaseSchema): - id = fields.Integer(data_key="experimentId", required=True) - name = fields.String(required=True) - description = fields.String(required=True) - artifact_location = fields.String( - data_key="artifactLocation", required=True - ) - created_at = fields.DateTime(data_key="createdAt", required=True) - last_updated_at = fields.DateTime(data_key="lastUpdatedAt", required=True) - deleted_at = fields.DateTime(data_key="deletedAt", missing=None) - - @post_load - def make_experiment(self, data): - return Experiment(**data) - - -class ExperimentRunSchema(BaseSchema): - id = fields.UUID(data_key="runId", required=True) - run_number = fields.Integer(data_key="runNumber", required=True) - experiment_id = fields.Integer(data_key="experimentId", required=True) - name = 
fields.String(required=True) - parent_run_id = fields.UUID(data_key="parentRunId", missing=None) - artifact_location = fields.String( - data_key="artifactLocation", required=True - ) - status = EnumField(ExperimentRunStatus, by_value=True, required=True) - started_at = fields.DateTime(data_key="startedAt", required=True) - ended_at = fields.DateTime(data_key="endedAt", missing=None) - deleted_at = fields.DateTime(data_key="deletedAt", missing=None) - tags = fields.Nested(TagSchema, many=True, required=True) - params = fields.Nested(ParamSchema, many=True, required=True) - metrics = fields.Nested(MetricSchema, many=True, required=True) - - @post_load - def make_experiment_run(self, data): - return ExperimentRun(**data) - - -class ExperimentRunDataSchema(BaseSchema): - metrics = fields.List(fields.Nested(MetricSchema)) - params = fields.List(fields.Nested(ParamSchema)) - tags = fields.List(fields.Nested(TagSchema)) - - -class ExperimentRunInfoSchema(BaseSchema): - status = EnumField(ExperimentRunStatus, by_value=True, required=True) - ended_at = fields.DateTime(data_key="endedAt", missing=None) - - -class PageSchema(BaseSchema): - start = fields.Integer(required=True) - limit = fields.Integer(required=True) - - @post_load - def make_page(self, data): - return Page(**data) - - -class PaginationSchema(BaseSchema): - start = fields.Integer(required=True) - size = fields.Integer(required=True) - previous = fields.Nested(PageSchema, missing=None) - next = fields.Nested(PageSchema, missing=None) - - @post_load - def make_pagination(self, data): - return Pagination(**data) - - -class ListExperimentRunsResponseSchema(BaseSchema): - pagination = fields.Nested(PaginationSchema, required=True) - runs = fields.Nested(ExperimentRunSchema, many=True, required=True) - - @post_load - def make_list_runs_response_schema(self, data): - return ListExperimentRunsResponse(**data) - - -class CreateRunSchema(BaseSchema): - name = fields.String() - parent_run_id = 
fields.UUID(data_key="parentRunId") - started_at = fields.DateTime(data_key="startedAt") - artifact_location = fields.String(data_key="artifactLocation") - tags = fields.Nested(TagSchema, many=True, required=True) - - -class DeleteExperimentRunsResponseSchema(BaseSchema): - deleted_run_ids = fields.List( - fields.UUID(), data_key="deletedRunIds", required=True - ) - conflicted_run_ids = fields.List( - fields.UUID(), data_key="conflictedRunIds", required=True - ) - - @post_load - def make_delete_runs_response(self, data): - return DeleteExperimentRunsResponse(**data) - - -class RestoreExperimentRunsResponseSchema(BaseSchema): - restored_run_ids = fields.List( - fields.UUID(), data_key="restoredRunIds", required=True - ) - conflicted_run_ids = fields.List( - fields.UUID(), data_key="conflictedRunIds", required=True - ) - - @post_load - def make_restore_runs_response(self, data): - return RestoreExperimentRunsResponse(**data) - - -class MetricDataPointSchema(BaseSchema): - """Deserialise a data point from the metric history endpoint. - - This schema is written with the expectation that it is not used alongside - the metric subsampling feature, which can result in null timestamp or step, - or a non-integer step. - """ - - value = fields.Float(required=True) - timestamp = fields.DateTime(required=True) - step = fields.Integer(required=True) - - @post_load - def make_metric(self, data): - return MetricDataPoint(**data) - - -class MetricHistorySchema(BaseSchema): - original_size = fields.Integer(data_key="originalSize", required=True) - subsampled = fields.Boolean(required=True) - key = fields.String(required=True) - history = fields.Nested(MetricDataPointSchema, many=True, required=True) - - @post_load - def make_history(self, data): - return MetricHistory(**data) - - class ExperimentClient(BaseClient): SERVICE_NAME = "atlas" @@ -526,6 +313,9 @@ def list_runs( To filter runs of experiments with the given IDs only. 
If an empty list is passed, a result with an empty list of runs is returned. By default, runs from all experiments are returned. + lifecycle_stage: LifecycleStage, optional + To filter runs of experiments in a specific lifecycle stage only. + By default, runs in any stage are returned. start : int, optional The (zero-indexed) starting point of runs to retrieve. limit : int, optional @@ -535,10 +325,11 @@ def list_runs( ------- ListExperimentRunsResponse """ - if lifecycle_stage is not None: - raise NotImplementedError("lifecycle_stage is not supported.") - query_params = [] + experiment_ids_filter = None + lifecycle_filter = None + filter = None + if experiment_ids is not None: if len(experiment_ids) == 0: return ListExperimentRunsResponse( @@ -547,17 +338,73 @@ def list_runs( start=0, size=0, previous=None, next=None ), ) - for experiment_id in experiment_ids: - query_params.append(("experimentId", experiment_id)) + experiment_id_filters = [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, experiment_id) + for experiment_id in experiment_ids + ] + experiment_ids_filter = CompoundFilter( + LogicalOperator.OR, experiment_id_filters + ) + if lifecycle_stage is not None: + lifecycle_filter = DeletedAtFilter( + ComparisonOperator.DEFINED, + lifecycle_stage == LifecycleStage.DELETED, + ) - if start is not None: - query_params.append(("start", start)) - if limit is not None: - query_params.append(("limit", limit)) + if experiment_ids_filter is not None and lifecycle_filter is not None: + filter = CompoundFilter( + LogicalOperator.AND, [experiment_ids_filter, lifecycle_filter] + ) + elif experiment_ids_filter is not None: + filter = experiment_ids_filter + elif lifecycle_filter is not None: + filter = lifecycle_filter - endpoint = "/project/{}/run".format(project_id) - return self._get( - endpoint, ListExperimentRunsResponseSchema(), params=query_params + return self.query_runs(project_id, filter, None, start, limit) + + def query_runs( + self, project_id, filter=None, 
sort=None, start=None, limit=None + ): + """Query experiment runs. + + This method returns pages of runs. If less than the full number of runs + for the project is returned, the ``next`` page of the returned response + object will not be ``None``: + + >>> response = client.query_runs(project_id) + >>> response.pagination.next + Page(start=10, limit=10) + + Get all experiment runs by making successive calls to ``query_runs``, + passing the ``start`` and ``limit`` of the ``next`` page each time + until ``next`` is returned as ``None``. + + Parameters + ---------- + project_id : uuid.UUID + filter : SingleFilter or CompoundFilter, optional + To filter runs of experiments with the given filter. By default, + runs from all experiments are returned. + sort : List[Sort], optional + Runs are ordered using the conditions in sort. The relative + importance of each condition gradually decreases in order. + By default, experiment runs are sorted by their startedAt value. + start : int, optional + The (zero-indexed) starting point of runs to retrieve. + limit : int, optional + The maximum number of runs to retrieve. 
+ + Returns + ------- + ListExperimentRunsResponse + """ + endpoint = "/project/{}/run/query".format(project_id) + page = None + if start is not None and limit is not None: + page = Page(start, limit) + payload = RunQuerySchema().dump(RunQuery(filter, sort, page)) + return self._post( + endpoint, ListExperimentRunsResponseSchema(), json=payload ) def log_run_data( @@ -673,15 +520,12 @@ def delete_runs(self, project_id, run_ids=None): deleted_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) + for run_id in run_ids + ] + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, DeleteExperimentRunsResponseSchema(), json=payload @@ -714,15 +558,12 @@ def restore_runs(self, project_id, run_ids=None): restored_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) + for run_id in run_ids + ] + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, RestoreExperimentRunsResponseSchema(), json=payload diff --git a/faculty/clients/experiment/_models.py b/faculty/clients/experiment/_models.py new file mode 100644 index 00000000..00b6fecc --- /dev/null +++ b/faculty/clients/experiment/_models.py @@ -0,0 +1,128 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import namedtuple +from enum import Enum + + +class LifecycleStage(Enum): + ACTIVE = "active" + DELETED = "deleted" + + +class ExperimentRunStatus(Enum): + RUNNING = "running" + FINISHED = "finished" + FAILED = "failed" + SCHEDULED = "scheduled" + KILLED = "killed" + + +Page = namedtuple("Page", ["start", "limit"]) +Pagination = namedtuple("Pagination", ["start", "size", "previous", "next"]) + +Metric = namedtuple("Metric", ["key", "value", "timestamp", "step"]) +Param = namedtuple("Param", ["key", "value"]) +Tag = namedtuple("Tag", ["key", "value"]) + +Experiment = namedtuple( + "Experiment", + [ + "id", + "name", + "description", + "artifact_location", + "created_at", + "last_updated_at", + "deleted_at", + ], +) + + +ExperimentRun = namedtuple( + "ExperimentRun", + [ + "id", + "run_number", + "experiment_id", + "name", + "parent_run_id", + "artifact_location", + "status", + "started_at", + "ended_at", + "deleted_at", + "tags", + "params", + "metrics", + ], +) + + +class ComparisonOperator(Enum): + DEFINED = "defined" + EQUAL_TO = "eq" + NOT_EQUAL_TO = "ne" + LESS_THAN = "lt" + LESS_THAN_OR_EQUAL_TO = "le" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL_TO = "ge" + + +ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) +ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) +RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) +DeletedAtFilter = namedtuple("DeletedAtFilter", ["operator", "value"]) +TagFilter = namedtuple("TagFilter", ["key", "operator", "value"]) +ParamFilter = 
namedtuple("ParamFilter", ["key", "operator", "value"]) +MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) + + +class LogicalOperator(Enum): + AND = "and" + OR = "or" + + +CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) + + +class SortOrder(Enum): + ASC = "asc" + DESC = "desc" + + +StartedAtSort = namedtuple("StartedAtSort", ["order"]) +RunNumberSort = namedtuple("RunNumberSort", ["order"]) +DurationSort = namedtuple("DurationSort", ["order"]) +TagSort = namedtuple("TagSort", ["key", "order"]) +ParamSort = namedtuple("ParamSort", ["key", "order"]) +MetricSort = namedtuple("MetricSort", ["key", "order"]) + +RunQuery = namedtuple("RunQuery", ["filter", "sort", "page"]) + +MetricDataPoint = namedtuple("MetricDataPoint", ["value", "timestamp", "step"]) +MetricHistory = namedtuple( + "MetricHistory", ["original_size", "subsampled", "key", "history"] +) + +ListExperimentRunsResponse = namedtuple( + "ListExperimentRunsResponse", ["runs", "pagination"] +) +DeleteExperimentRunsResponse = namedtuple( + "DeleteExperimentRunsResponse", ["deleted_run_ids", "conflicted_run_ids"] +) +RestoreExperimentRunsResponse = namedtuple( + "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] +) diff --git a/faculty/clients/experiment/_schemas.py b/faculty/clients/experiment/_schemas.py new file mode 100644 index 00000000..173c2765 --- /dev/null +++ b/faculty/clients/experiment/_schemas.py @@ -0,0 +1,418 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from marshmallow import fields, post_load, pre_dump, ValidationError +from marshmallow_enum import EnumField +from marshmallow_oneofschema import OneOfSchema + +from faculty.clients.base import BaseSchema +from faculty.clients.experiment._models import ( + ComparisonOperator, + DeleteExperimentRunsResponse, + Experiment, + ExperimentRun, + ExperimentRunStatus, + ListExperimentRunsResponse, + LogicalOperator, + Metric, + MetricDataPoint, + MetricHistory, + Page, + Pagination, + Param, + RestoreExperimentRunsResponse, + SortOrder, + Tag, +) + + +class _OptionalField(fields.Field): + """Wrap another field, passing through Nones.""" + + def __init__(self, nested, *args, **kwargs): + self.nested = nested + super(_OptionalField, self).__init__(*args, **kwargs) + + def _deserialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._deserialize(value, *args, **kwargs) + + def _serialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._serialize(value, *args, **kwargs) + + +class _OneOfSchemaWithoutType(OneOfSchema): + def dump(self, *args, **kwargs): + data = super(_OneOfSchemaWithoutType, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} + + +class PageSchema(BaseSchema): + start = fields.Integer(required=True) + limit = fields.Integer(required=True) + + @post_load + def make_page(self, data): + return Page(**data) + + +class PaginationSchema(BaseSchema): + start = fields.Integer(required=True) + size = fields.Integer(required=True) + previous = fields.Nested(PageSchema, missing=None) + next = fields.Nested(PageSchema, missing=None) + + @post_load + def make_pagination(self, data): + return Pagination(**data) + + +class MetricSchema(BaseSchema): + key = fields.String(required=True) + 
value = fields.Float(required=True) + timestamp = fields.DateTime(required=True) + step = fields.Integer(required=True) + + @post_load + def make_metric(self, data): + return Metric(**data) + + +class ParamSchema(BaseSchema): + key = fields.String(required=True) + value = fields.String(required=True) + + @post_load + def make_param(self, data): + return Param(**data) + + +class TagSchema(BaseSchema): + key = fields.String(required=True) + value = fields.String(required=True) + + @post_load + def make_tag(self, data): + return Tag(**data) + + +class ExperimentSchema(BaseSchema): + id = fields.Integer(data_key="experimentId", required=True) + name = fields.String(required=True) + description = fields.String(required=True) + artifact_location = fields.String( + data_key="artifactLocation", required=True + ) + created_at = fields.DateTime(data_key="createdAt", required=True) + last_updated_at = fields.DateTime(data_key="lastUpdatedAt", required=True) + deleted_at = fields.DateTime(data_key="deletedAt", missing=None) + + @post_load + def make_experiment(self, data): + return Experiment(**data) + + +class ExperimentRunSchema(BaseSchema): + id = fields.UUID(data_key="runId", required=True) + run_number = fields.Integer(data_key="runNumber", required=True) + experiment_id = fields.Integer(data_key="experimentId", required=True) + name = fields.String(required=True) + parent_run_id = fields.UUID(data_key="parentRunId", missing=None) + artifact_location = fields.String( + data_key="artifactLocation", required=True + ) + status = EnumField(ExperimentRunStatus, by_value=True, required=True) + started_at = fields.DateTime(data_key="startedAt", required=True) + ended_at = fields.DateTime(data_key="endedAt", missing=None) + deleted_at = fields.DateTime(data_key="deletedAt", missing=None) + tags = fields.Nested(TagSchema, many=True, required=True) + params = fields.Nested(ParamSchema, many=True, required=True) + metrics = fields.Nested(MetricSchema, many=True, required=True) + + 
@post_load + def make_experiment_run(self, data): + return ExperimentRun(**data) + + +# Schemas for payloads sent to API: + + +class ExperimentRunDataSchema(BaseSchema): + metrics = fields.List(fields.Nested(MetricSchema)) + params = fields.List(fields.Nested(ParamSchema)) + tags = fields.List(fields.Nested(TagSchema)) + + +class ExperimentRunInfoSchema(BaseSchema): + status = EnumField(ExperimentRunStatus, by_value=True, required=True) + ended_at = fields.DateTime(data_key="endedAt", missing=None) + + +class ListExperimentRunsResponseSchema(BaseSchema): + pagination = fields.Nested(PaginationSchema, required=True) + runs = fields.Nested(ExperimentRunSchema, many=True, required=True) + + @post_load + def make_list_runs_response_schema(self, data): + return ListExperimentRunsResponse(**data) + + +class CreateRunSchema(BaseSchema): + name = fields.String() + parent_run_id = fields.UUID(data_key="parentRunId") + started_at = fields.DateTime(data_key="startedAt") + artifact_location = fields.String(data_key="artifactLocation") + tags = fields.Nested(TagSchema, many=True, required=True) + + +class _ParamFilterValueField(fields.Field): + """Field that passes through strings or numbers.""" + + default_error_messages = { + "unsupported_type": "Param values must be of type str, int or float." 
+ } + + def _serialize(self, value, attr, obj, **kwargs): + if isinstance(value, str): + field = fields.String() + elif isinstance(value, int) or isinstance(value, float): + field = fields.Number() + else: + self.fail("unsupported_type") + return field._serialize(value, attr, obj, **kwargs) + + +class _FilterValueField(fields.Field): + def __init__(self, other_field_type, *args, **kwargs): + self.other_field_type = other_field_type + super(_FilterValueField, self).__init__(*args, **kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + if obj.operator == ComparisonOperator.DEFINED: + field_cls = fields.Boolean + else: + field_cls = self.other_field_type + return field_cls()._serialize(value, attr, obj, **kwargs) + + +def _validate_discrete(operator): + if operator not in { + ComparisonOperator.DEFINED, + ComparisonOperator.EQUAL_TO, + ComparisonOperator.NOT_EQUAL_TO, + }: + raise ValidationError({"operator": "Not a discrete operator."}) + + +class _ProjectIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = _FilterValueField(fields.UUID) + by = fields.Constant("projectId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class _ExperimentIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = _FilterValueField(fields.Integer) + by = fields.Constant("experimentId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class _RunIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = _FilterValueField(fields.UUID) + by = fields.Constant("runId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class _DeletedAtFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = _FilterValueField(fields.DateTime) + by = 
fields.Constant("deletedAt", dump_only=True) + + +class _TagFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = _FilterValueField(fields.String) + by = fields.Constant("tag", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class _ParamFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = _FilterValueField(_ParamFilterValueField) + by = fields.Constant("param", dump_only=True) + + @pre_dump + def check_operator(self, obj): + if isinstance(obj.value, str): + _validate_discrete(obj.operator) + return obj + + +class _MetricFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = _FilterValueField(fields.Float) + by = fields.Constant("metric", dump_only=True) + + +class _CompoundFilterSchema(BaseSchema): + operator = EnumField(LogicalOperator, by_value=True) + conditions = fields.List(fields.Nested("FilterSchema")) + + +class FilterSchema(_OneOfSchemaWithoutType): + type_schemas = { + "ProjectIdFilter": _ProjectIdFilterSchema, + "ExperimentIdFilter": _ExperimentIdFilterSchema, + "RunIdFilter": _RunIdFilterSchema, + "DeletedAtFilter": _DeletedAtFilterSchema, + "TagFilter": _TagFilterSchema, + "ParamFilter": _ParamFilterSchema, + "MetricFilter": _MetricFilterSchema, + "CompoundFilter": _CompoundFilterSchema, + } + + +class _StartedAtSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("startedAt", dump_only=True) + + +class _RunNumberSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("runNumber", dump_only=True) + + +class _DurationSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("duration", dump_only=True) + + +class _TagSortSchema(BaseSchema): + key = fields.String() + order = 
EnumField(SortOrder, by_value=True) + by = fields.Constant("tag", dump_only=True) + + +class _ParamSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("param", dump_only=True) + + +class _MetricSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("metric", dump_only=True) + + +class SortSchema(_OneOfSchemaWithoutType): + type_schemas = { + "StartedAtSort": _StartedAtSortSchema, + "RunNumberSort": _RunNumberSortSchema, + "DurationSort": _DurationSortSchema, + "TagSort": _TagSortSchema, + "ParamSort": _ParamSortSchema, + "MetricSort": _MetricSortSchema, + } + + +class RunQuerySchema(BaseSchema): + filter = _OptionalField(fields.Nested(FilterSchema)) + sort = fields.List(fields.Nested(SortSchema)) + page = fields.Nested(PageSchema, missing=None) + + +# Schemas for responses returned from API: + + +class DeleteExperimentRunsResponseSchema(BaseSchema): + deleted_run_ids = fields.List( + fields.UUID(), data_key="deletedRunIds", required=True + ) + conflicted_run_ids = fields.List( + fields.UUID(), data_key="conflictedRunIds", required=True + ) + + @post_load + def make_delete_runs_response(self, data): + return DeleteExperimentRunsResponse(**data) + + +class RestoreExperimentRunsResponseSchema(BaseSchema): + restored_run_ids = fields.List( + fields.UUID(), data_key="restoredRunIds", required=True + ) + conflicted_run_ids = fields.List( + fields.UUID(), data_key="conflictedRunIds", required=True + ) + + @post_load + def make_restore_runs_response(self, data): + return RestoreExperimentRunsResponse(**data) + + +class _MetricDataPointSchema(BaseSchema): + """Deserialise a data point from the metric history endpoint. + + This schema is written with the expectation that it is not used alongside + the metric subsampling feature, which can result in null timestamp or step, + or a non-integer step. 
+ """ + + value = fields.Float(required=True) + timestamp = fields.DateTime(required=True) + step = fields.Integer(required=True) + + @post_load + def make_metric(self, data): + return MetricDataPoint(**data) + + +class MetricHistorySchema(BaseSchema): + original_size = fields.Integer(data_key="originalSize", required=True) + subsampled = fields.Boolean(required=True) + key = fields.String(required=True) + history = fields.Nested(_MetricDataPointSchema, many=True, required=True) + + @post_load + def make_history(self, data): + return MetricHistory(**data) diff --git a/setup.py b/setup.py index a8772681..1e04d352 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ # compatible version of python-dateutil is available "marshmallow[reco]>=3.0.0rc3", "marshmallow_enum", + "marshmallow-oneofschema>=2.0.0b2", "boto3", "botocore", ], diff --git a/tests/clients/experiment/__init__.py b/tests/clients/experiment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/clients/experiment/test_client.py b/tests/clients/experiment/test_client.py new file mode 100644 index 00000000..4476d9fa --- /dev/null +++ b/tests/clients/experiment/test_client.py @@ -0,0 +1,615 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from uuid import uuid4 + +import pytest + +from faculty.clients.base import Conflict +from faculty.clients.experiment._client import ( + ExperimentClient, + ExperimentDeleted, + ExperimentNameConflict, + ParamConflict, +) +from faculty.clients.experiment._models import ( + ComparisonOperator, + CompoundFilter, + DeletedAtFilter, + ExperimentIdFilter, + LifecycleStage, + LogicalOperator, + Metric, + Page, + RunIdFilter, + RunQuery, +) + + +PROJECT_ID = uuid4() +EXPERIMENT_ID = 234 +EXPERIMENT_RUN_ID = uuid4() +PARENT_RUN_ID = uuid4() + + +@pytest.mark.parametrize("description", [None, "experiment description"]) +@pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) +def test_experiment_client_create(mocker, description, artifact_location): + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_post", return_value=experiment) + schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentSchema" + ) + + client = ExperimentClient(mocker.Mock()) + returned_experiment = client.create( + PROJECT_ID, "experiment name", description, artifact_location + ) + assert returned_experiment == experiment + + schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/experiment".format(PROJECT_ID), + schema_mock.return_value, + json={ + "name": "experiment name", + "description": description, + "artifactLocation": artifact_location, + }, + ) + + +def test_experiment_client_create_name_conflict(mocker): + error_code = "experiment_name_conflict" + exception = Conflict(mocker.Mock(), mocker.Mock(), error_code) + mocker.patch.object(ExperimentClient, "_post", side_effect=exception) + + client = ExperimentClient(mocker.Mock()) + with pytest.raises( + ExperimentNameConflict, match="name 'experiment name' already exists" + ): + client.create(PROJECT_ID, "experiment name") + + +def test_experiment_client_get(mocker): + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", 
return_value=experiment) + schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentSchema" + ) + + client = ExperimentClient(mocker.Mock()) + returned_experiment = client.get(PROJECT_ID, EXPERIMENT_ID) + assert returned_experiment == experiment + + schema_mock.assert_called_once_with() + ExperimentClient._get.assert_called_once_with( + "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT_ID), + schema_mock.return_value, + ) + + +def test_experiment_client_list(mocker): + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) + schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentSchema" + ) + + client = ExperimentClient(mocker.Mock()) + assert client.list(PROJECT_ID) == [experiment] + + schema_mock.assert_called_once_with(many=True) + ExperimentClient._get.assert_called_once_with( + "/project/{}/experiment".format(PROJECT_ID), + schema_mock.return_value, + params={}, + ) + + +def test_experiment_client_list_lifecycle_filter(mocker): + experiment = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", return_value=[experiment]) + schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentSchema" + ) + + client = ExperimentClient(mocker.Mock()) + returned_experiments = client.list( + PROJECT_ID, lifecycle_stage=LifecycleStage.ACTIVE + ) + assert returned_experiments == [experiment] + + schema_mock.assert_called_once_with(many=True) + ExperimentClient._get.assert_called_once_with( + "/project/{}/experiment".format(PROJECT_ID), + schema_mock.return_value, + params={"lifecycleStage": "active"}, + ) + + +@pytest.mark.parametrize("name", [None, "new name"]) +@pytest.mark.parametrize("description", [None, "new description"]) +def test_experiment_client_update(mocker, name, description): + mocker.patch.object(ExperimentClient, "_patch_raw") + + client = ExperimentClient(mocker.Mock()) + client.update( + PROJECT_ID, EXPERIMENT_ID, name=name, 
description=description + ) + + ExperimentClient._patch_raw.assert_called_once_with( + "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT_ID), + json={"name": name, "description": description}, + ) + + +def test_experiment_client_update_name_conflict(mocker): + error_code = "experiment_name_conflict" + exception = Conflict(mocker.Mock(), mocker.Mock(), error_code) + mocker.patch.object(ExperimentClient, "_patch_raw", side_effect=exception) + + client = ExperimentClient(mocker.Mock()) + with pytest.raises( + ExperimentNameConflict, match="name 'new name' already exists" + ): + client.update(PROJECT_ID, EXPERIMENT_ID, name="new name") + + +def test_delete(mocker): + mocker.patch.object(ExperimentClient, "_delete_raw") + + client = ExperimentClient(mocker.Mock()) + client.delete(PROJECT_ID, EXPERIMENT_ID) + + ExperimentClient._delete_raw.assert_called_once_with( + "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT_ID) + ) + + +def test_restore(mocker): + mocker.patch.object(ExperimentClient, "_put_raw") + + client = ExperimentClient(mocker.Mock()) + client.restore(PROJECT_ID, EXPERIMENT_ID) + + ExperimentClient._put_raw.assert_called_once_with( + "/project/{}/experiment/{}/restore".format(PROJECT_ID, EXPERIMENT_ID) + ) + + +def test_experiment_create_run(mocker): + run = mocker.Mock() + mocker.patch.object(ExperimentClient, "_post", return_value=run) + request_schema_mock = mocker.patch( + "faculty.clients.experiment._client.CreateRunSchema" + ) + dump_mock = request_schema_mock.return_value.dump + response_schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentRunSchema" + ) + run_name = mocker.Mock() + started_at = mocker.Mock() + artifact_location = mocker.Mock() + + client = ExperimentClient(mocker.Mock()) + returned_run = client.create_run( + PROJECT_ID, + EXPERIMENT_ID, + run_name, + started_at, + PARENT_RUN_ID, + artifact_location=artifact_location, + ) + assert returned_run == run + + 
request_schema_mock.assert_called_once_with() + dump_mock.assert_called_once_with( + { + "name": run_name, + "parent_run_id": PARENT_RUN_ID, + "started_at": started_at, + "artifact_location": artifact_location, + "tags": [], + } + ) + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/experiment/{}/run".format(PROJECT_ID, EXPERIMENT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, + ) + + +def test_experiment_create_run_experiment_deleted_conflict(mocker): + message = "experiment deleted" + error_code = "experiment_deleted" + response_mock = mocker.Mock() + response_mock.json.return_value = {"experimentId": 42} + exception = Conflict(response_mock, message, error_code) + + mocker.patch.object(ExperimentClient, "_post", side_effect=exception) + + client = ExperimentClient(mocker.Mock()) + with pytest.raises(ExperimentDeleted, match=message): + client.create_run( + PROJECT_ID, + EXPERIMENT_ID, + name=mocker.Mock(), + started_at=mocker.Mock(), + parent_run_id=PARENT_RUN_ID, + artifact_location=mocker.Mock(), + ) + + +def test_experiment_client_get_run(mocker): + run = mocker.Mock() + mocker.patch.object(ExperimentClient, "_get", return_value=run) + schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentRunSchema" + ) + + client = ExperimentClient(mocker.Mock()) + returned_run = client.get_run(PROJECT_ID, EXPERIMENT_RUN_ID) + assert returned_run == run + + schema_mock.assert_called_once_with() + ExperimentClient._get.assert_called_once_with( + "/project/{}/run/{}".format(PROJECT_ID, EXPERIMENT_RUN_ID), + schema_mock.return_value, + ) + + +def test_experiment_client_list_runs(mocker): + mocker.patch.object(ExperimentClient, "query_runs") + + client = ExperimentClient(mocker.Mock()) + response = client.list_runs( + PROJECT_ID, + experiment_ids=[123, 456], + lifecycle_stage=LifecycleStage.DELETED, + start=20, + limit=10, + ) + + assert response == 
ExperimentClient.query_runs.return_value + expected_filter = CompoundFilter( + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.OR, + [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 123), + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 456), + ], + ), + DeletedAtFilter(ComparisonOperator.DEFINED, True), + ], + ) + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, expected_filter, None, 20, 10 + ) + + +def test_experiment_client_list_runs_defaults(mocker): + mocker.patch.object(ExperimentClient, "query_runs") + + client = ExperimentClient(mocker.Mock()) + response = client.list_runs(PROJECT_ID) + + assert response == ExperimentClient.query_runs.return_value + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, None, None, None, None + ) + + +def test_experiment_client_query_runs(mocker): + list_response = mocker.Mock() + mocker.patch.object(ExperimentClient, "_post", return_value=list_response) + response_schema_mock = mocker.patch( + "faculty.clients.experiment._client.ListExperimentRunsResponseSchema" + ) + request_schema_mock = mocker.patch( + "faculty.clients.experiment._client.RunQuerySchema" + ) + request_dump_mock = request_schema_mock.return_value.dump + + filter = mocker.Mock() + sort = mocker.Mock() + + client = ExperimentClient(mocker.Mock()) + list_result = client.query_runs( + PROJECT_ID, filter, sort, start=20, limit=10 + ) + + assert list_result == list_response + + request_dump_mock.assert_called_once_with( + RunQuery(filter, sort, Page(20, 10)) + ) + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=request_dump_mock.return_value, + ) + + +def test_log_run_data(mocker): + mocker.patch.object(ExperimentClient, "_patch_raw") + run_data_schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentRunDataSchema" + ) + run_data_dump_mock = 
run_data_schema_mock.return_value.dump + + metric = mocker.Mock() + param = mocker.Mock() + tag = mocker.Mock() + + client = ExperimentClient(mocker.Mock()) + client.log_run_data( + PROJECT_ID, + EXPERIMENT_RUN_ID, + metrics=[metric], + params=[param], + tags=[tag], + ) + + run_data_schema_mock.assert_called_once_with() + run_data_dump_mock.assert_called_once_with( + {"metrics": [metric], "params": [param], "tags": [tag]} + ) + ExperimentClient._patch_raw.assert_called_once_with( + "/project/{}/run/{}/data".format(PROJECT_ID, EXPERIMENT_RUN_ID), + json=run_data_dump_mock.return_value, + ) + + +def test_log_run_data_param_conflict(mocker): + message = "bad params" + error_code = "conflicting_params" + response_mock = mocker.Mock() + response_mock.json.return_value = {"parameterKeys": ["bad-key"]} + exception = Conflict(response_mock, message, error_code) + + mocker.patch.object(ExperimentClient, "_patch_raw", side_effect=exception) + + client = ExperimentClient(mocker.Mock()) + + with pytest.raises(ParamConflict, match=message): + client.log_run_data( + PROJECT_ID, EXPERIMENT_RUN_ID, params=[mocker.Mock()] + ) + + +def test_log_run_data_other_conflict(mocker): + response_mock = mocker.Mock() + exception = Conflict(response_mock, "", "") + + mocker.patch.object(ExperimentClient, "_patch_raw", side_effect=exception) + client = ExperimentClient(mocker.Mock()) + + with pytest.raises(Conflict): + client.log_run_data( + PROJECT_ID, EXPERIMENT_RUN_ID, params=[mocker.Mock()] + ) + + +def test_log_run_data_empty(mocker): + mocker.patch.object(ExperimentClient, "_patch_raw") + + client = ExperimentClient(mocker.Mock()) + + client.log_run_data(PROJECT_ID, EXPERIMENT_RUN_ID) + ExperimentClient._patch_raw.assert_not_called() + + +def test_update_run_info(mocker): + run = mocker.Mock() + mocker.patch.object(ExperimentClient, "_patch", return_value=run) + run_schema_mock = mocker.patch( + "faculty.clients.experiment._client.ExperimentRunSchema" + ) + run_info_schema_mock = 
mocker.patch( + "faculty.clients.experiment._client.ExperimentRunInfoSchema" + ) + run_info_dump_mock = run_info_schema_mock.return_value.dump + + status = mocker.Mock() + ended_at = mocker.Mock() + + client = ExperimentClient(mocker.Mock()) + returned_run = client.update_run_info( + PROJECT_ID, EXPERIMENT_RUN_ID, status, ended_at + ) + assert returned_run == run + + run_schema_mock.assert_called_once_with() + run_info_schema_mock.assert_called_once_with() + run_info_dump_mock.assert_called_once_with( + {"status": status, "ended_at": ended_at} + ) + ExperimentClient._patch.assert_called_once_with( + "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), + run_schema_mock.return_value, + json=run_info_dump_mock.return_value, + ) + + +def test_get_metric_history(mocker): + key = mocker.Mock() + data_point_0 = mocker.Mock() + data_point_1 = mocker.Mock() + metric_history = mocker.Mock(key=key, history=[data_point_0, data_point_1]) + + mocker.patch.object(ExperimentClient, "_get", return_value=metric_history) + metric_history_schema_mock = mocker.patch( + "faculty.clients.experiment._client.MetricHistorySchema" + ) + + client = ExperimentClient(mocker.Mock()) + metrics = client.get_metric_history( + PROJECT_ID, EXPERIMENT_RUN_ID, "metric-key" + ) + + expected = [ + Metric( + key=key, + step=data_point_0.step, + timestamp=data_point_0.timestamp, + value=data_point_0.value, + ), + Metric( + key=key, + step=data_point_1.step, + timestamp=data_point_1.timestamp, + value=data_point_1.value, + ), + ] + assert metrics == expected + + metric_history_schema_mock.assert_called_once_with() + ExperimentClient._get.assert_called_once_with( + "/project/{}/run/{}/metric/metric-key/history".format( + PROJECT_ID, EXPERIMENT_RUN_ID + ), + metric_history_schema_mock.return_value, + ) + + +def test_delete_runs(mocker): + delete_runs_response = mocker.Mock() + mocker.patch.object( + ExperimentClient, "_post", return_value=delete_runs_response + ) + response_schema_mock = 
mocker.patch( + "faculty.clients.experiment._client.DeleteExperimentRunsResponseSchema" + ) + filter_schema_mock = mocker.patch( + "faculty.clients.experiment._client.FilterSchema" + ) + filter_dump_mock = filter_schema_mock.return_value.dump + + run_ids = [uuid4(), uuid4()] + + client = ExperimentClient(mocker.Mock()) + response = client.delete_runs(PROJECT_ID, run_ids) + + assert response == delete_runs_response + + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/delete/query".format(PROJECT_ID), + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, + ) + + +def test_delete_runs_no_run_ids(mocker): + mocker.patch.object(ExperimentClient, "_post") + schema_mock = mocker.patch( + "faculty.clients.experiment._client.DeleteExperimentRunsResponseSchema" + ) + + client = ExperimentClient(mocker.Mock()) + client.delete_runs(PROJECT_ID) + + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/delete/query".format(PROJECT_ID), + schema_mock.return_value, + json={}, + ) + + +def test_delete_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + + client = ExperimentClient(mocker.Mock()) + response = client.delete_runs(PROJECT_ID, run_ids=[]) + + ExperimentClient._post.assert_not_called() + assert len(response.deleted_run_ids) == 0 + assert len(response.conflicted_run_ids) == 0 + + +def test_restore_runs(mocker): + restore_runs_response = mocker.Mock() + mocker.patch.object( + ExperimentClient, "_post", return_value=restore_runs_response + ) + response_schema_mock = mocker.patch( + "faculty.clients.experiment._client." 
+ "RestoreExperimentRunsResponseSchema" + ) + filter_schema_mock = mocker.patch( + "faculty.clients.experiment._client.FilterSchema" + ) + filter_dump_mock = filter_schema_mock.return_value.dump + + run_ids = [uuid4(), uuid4()] + + client = ExperimentClient(mocker.Mock()) + response = client.restore_runs(PROJECT_ID, run_ids) + + assert response == restore_runs_response + + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/restore/query".format(PROJECT_ID), + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, + ) + + +def test_restore_runs_no_run_ids(mocker): + mocker.patch.object(ExperimentClient, "_post") + schema_mock = mocker.patch( + "faculty.clients.experiment._client." + "RestoreExperimentRunsResponseSchema" + ) + + client = ExperimentClient(mocker.Mock()) + client.restore_runs(PROJECT_ID) + + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/restore/query".format(PROJECT_ID), + schema_mock.return_value, + json={}, + ) + + +def test_restore_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + + client = ExperimentClient(mocker.Mock()) + response = client.restore_runs(PROJECT_ID, run_ids=[]) + + ExperimentClient._post.assert_not_called() + assert len(response.restored_run_ids) == 0 + assert len(response.conflicted_run_ids) == 0 diff --git a/tests/clients/experiment/test_schemas.py b/tests/clients/experiment/test_schemas.py new file mode 100644 index 00000000..a6cb2e4a --- /dev/null +++ b/tests/clients/experiment/test_schemas.py @@ -0,0 +1,728 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import datetime +from uuid import uuid4 + +import pytest +from marshmallow import ValidationError +from pytz import UTC + +from faculty.clients.experiment._models import ( + ComparisonOperator, + CompoundFilter, + DeleteExperimentRunsResponse, + DeletedAtFilter, + DurationSort, + Experiment, + ExperimentIdFilter, + ExperimentRun, + ExperimentRunStatus, + ListExperimentRunsResponse, + LogicalOperator, + Metric, + MetricDataPoint, + MetricFilter, + MetricHistory, + MetricSort, + Page, + Pagination, + Param, + ParamFilter, + ParamSort, + ProjectIdFilter, + RestoreExperimentRunsResponse, + RunIdFilter, + RunNumberSort, + RunQuery, + SortOrder, + StartedAtSort, + Tag, + TagFilter, + TagSort, +) +from faculty.clients.experiment._schemas import ( + CreateRunSchema, + DeleteExperimentRunsResponseSchema, + ExperimentRunDataSchema, + ExperimentRunSchema, + ExperimentSchema, + FilterSchema, + ListExperimentRunsResponseSchema, + MetricHistorySchema, + MetricSchema, + PageSchema, + PaginationSchema, + ParamSchema, + RestoreExperimentRunsResponseSchema, + RunQuerySchema, + SortSchema, + TagSchema, +) + +PROJECT_ID = uuid4() +EXPERIMENT_ID = 661 +EXPERIMENT_RUN_ID = uuid4() +EXPERIMENT_RUN_NUMBER = 3 +EXPERIMENT_RUN_NAME = "run name" +PARENT_RUN_ID = uuid4() +CREATED_AT = datetime(2018, 3, 10, 11, 32, 6, 247000, tzinfo=UTC) +CREATED_AT_STRING = "2018-03-10T11:32:06.247Z" +LAST_UPDATED_AT = datetime(2018, 3, 10, 11, 32, 30, 172000, tzinfo=UTC) +LAST_UPDATED_AT_STRING = "2018-03-10T11:32:30.172Z" +DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, 
tzinfo=UTC) +DELETED_AT_STRING = "2018-03-10T11:37:42.482Z" +DELETED_AT_STRING_PYTHON = "2018-03-10T11:37:42.482000+00:00" + +EXPERIMENT = Experiment( + id=EXPERIMENT_ID, + name="experiment name", + description="experiment description", + artifact_location="https://example.com", + created_at=CREATED_AT, + last_updated_at=LAST_UPDATED_AT, + deleted_at=DELETED_AT, +) +EXPERIMENT_BODY = { + "experimentId": EXPERIMENT_ID, + "name": EXPERIMENT.name, + "description": EXPERIMENT.description, + "artifactLocation": EXPERIMENT.artifact_location, + "createdAt": CREATED_AT_STRING, + "lastUpdatedAt": LAST_UPDATED_AT_STRING, + "deletedAt": DELETED_AT_STRING, +} + +RUN_ID = uuid4() +RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC) +RUN_STARTED_AT_NO_TIMEZONE = datetime(2018, 3, 10, 11, 39, 12, 110000) +RUN_STARTED_AT_STRING_PYTHON = "2018-03-10T11:39:12.110000+00:00" +RUN_STARTED_AT_STRING_JAVA = "2018-03-10T11:39:12.11Z" +RUN_ENDED_AT = datetime(2018, 3, 10, 11, 39, 15, 110000, tzinfo=UTC) +RUN_ENDED_AT_STRING = "2018-03-10T11:39:15.11Z" + +TAG = Tag(key="tag-key", value="tag-value") +TAG_BODY = {"key": "tag-key", "value": "tag-value"} + +OTHER_TAG = Tag(key="other-tag-key", value="other-tag-value") +OTHER_TAG_BODY = {"key": "other-tag-key", "value": "other-tag-value"} + +PARAM = Param(key="param-key", value="param-value") +PARAM_BODY = {"key": "param-key", "value": "param-value"} + +METRIC_KEY = "metric-key" +METRIC = Metric( + key=METRIC_KEY, + value=123.0, + timestamp=datetime(2018, 3, 12, 16, 20, 22, 122000, tzinfo=UTC), + step=0, +) +METRIC_BODY = { + "key": METRIC.key, + "value": METRIC.value, + "timestamp": "2018-03-12T16:20:22.122000+00:00", + "step": METRIC.step, +} + +METRIC_DATA_POINT = MetricDataPoint( + value=METRIC.value, timestamp=METRIC.timestamp, step=METRIC.step +) +METRIC_DATA_POINT_BODY = { + "value": METRIC_BODY["value"], + "timestamp": METRIC_BODY["timestamp"], + "step": METRIC_BODY["step"], +} + +METRIC_HISTORY = MetricHistory( + 
original_size=1, + subsampled=False, + key=METRIC_KEY, + history=[METRIC_DATA_POINT], +) +METRIC_HISTORY_BODY = { + "originalSize": METRIC_HISTORY.original_size, + "subsampled": METRIC_HISTORY.subsampled, + "key": METRIC_HISTORY.key, + "history": [METRIC_DATA_POINT_BODY], +} + +EXPERIMENT_RUN = ExperimentRun( + id=EXPERIMENT_RUN_ID, + run_number=EXPERIMENT_RUN_NUMBER, + name=EXPERIMENT_RUN_NAME, + parent_run_id=PARENT_RUN_ID, + experiment_id=EXPERIMENT.id, + artifact_location="faculty:", + status=ExperimentRunStatus.RUNNING, + started_at=RUN_STARTED_AT, + ended_at=RUN_ENDED_AT, + deleted_at=DELETED_AT, + tags=[TAG], + params=[PARAM], + metrics=[METRIC], +) +EXPERIMENT_RUN_BODY = { + "experimentId": EXPERIMENT.id, + "runId": str(EXPERIMENT_RUN_ID), + "runNumber": EXPERIMENT_RUN_NUMBER, + "name": EXPERIMENT_RUN_NAME, + "parentRunId": str(PARENT_RUN_ID), + "artifactLocation": "faculty:", + "status": "running", + "startedAt": RUN_STARTED_AT_STRING_JAVA, + "endedAt": RUN_ENDED_AT_STRING, + "deletedAt": DELETED_AT_STRING, + "tags": [TAG_BODY], + "metrics": [METRIC_BODY], + "params": [PARAM_BODY], +} + +EXPERIMENT_RUN_DATA_BODY = { + "metrics": [METRIC_BODY], + "params": [PARAM_BODY], + "tags": [TAG_BODY], +} + + +PAGE = Page(start=3, limit=10) +PAGE_BODY = {"start": PAGE.start, "limit": PAGE.limit} + +PAGINATION = Pagination( + start=20, + size=10, + previous=Page(start=10, limit=10), + next=Page(start=30, limit=10), +) +PAGINATION_BODY = { + "start": PAGINATION.start, + "size": PAGINATION.size, + "previous": { + "start": PAGINATION.previous.start, + "limit": PAGINATION.previous.limit, + }, + "next": {"start": PAGINATION.next.start, "limit": PAGINATION.next.limit}, +} + +LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse( + runs=[EXPERIMENT_RUN], pagination=PAGINATION +) +LIST_EXPERIMENT_RUNS_RESPONSE_BODY = { + "runs": [EXPERIMENT_RUN_BODY], + "pagination": PAGINATION_BODY, +} + +DELETE_EXPERIMENT_RUNS_RESPONSE = DeleteExperimentRunsResponse( + 
deleted_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] +) +DELETE_EXPERIMENT_RUNS_RESPONSE_BODY = { + "deletedRunIds": [ + str(run_id) + for run_id in DELETE_EXPERIMENT_RUNS_RESPONSE.deleted_run_ids + ], + "conflictedRunIds": [ + str(run_id) + for run_id in DELETE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids + ], +} + +RESTORE_EXPERIMENT_RUNS_RESPONSE = RestoreExperimentRunsResponse( + restored_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] +) +RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY = { + "restoredRunIds": [ + str(run_id) + for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.restored_run_ids + ], + "conflictedRunIds": [ + str(run_id) + for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids + ], +} + + +def test_experiment_schema(): + data = ExperimentSchema().load(EXPERIMENT_BODY) + assert data == EXPERIMENT + + +def test_experiment_schema_nullable_deleted_at(): + body = EXPERIMENT_BODY.copy() + body["deletedAt"] = None + data = ExperimentSchema().load(body) + assert data.deleted_at is None + + +def test_experiment_schema_invalid(): + with pytest.raises(ValidationError): + ExperimentSchema().load({}) + + +def test_experiment_run_schema(): + data = ExperimentRunSchema().load(EXPERIMENT_RUN_BODY) + assert data == EXPERIMENT_RUN + + +@pytest.mark.parametrize( + "data_key, field", + [ + ("parentRunId", "parent_run_id"), + ("endedAt", "ended_at"), + ("deletedAt", "deleted_at"), + ], +) +def test_experiment_run_schema_nullable_field(data_key, field): + body = EXPERIMENT_RUN_BODY.copy() + del body[data_key] + data = ExperimentRunSchema().load(body) + assert getattr(data, field) is None + + +@pytest.mark.parametrize("parent_run_id", [None, PARENT_RUN_ID]) +@pytest.mark.parametrize( + "started_at", + [RUN_STARTED_AT, RUN_STARTED_AT_NO_TIMEZONE], + ids=["timezone", "no timezone"], +) +@pytest.mark.parametrize("artifact_location", [None, "faculty:project-id"]) +@pytest.mark.parametrize("tags", [[], [{"key": "key", "value": "value"}]]) 
+def test_create_run_schema(parent_run_id, started_at, artifact_location, tags): + data = CreateRunSchema().dump( + { + "name": EXPERIMENT_RUN_NAME, + "parent_run_id": parent_run_id, + "started_at": started_at, + "artifact_location": artifact_location, + "tags": tags, + } + ) + assert data == { + "name": EXPERIMENT_RUN_NAME, + "parentRunId": None if parent_run_id is None else str(parent_run_id), + "startedAt": RUN_STARTED_AT_STRING_PYTHON, + "artifactLocation": artifact_location, + "tags": tags, + } + + +def test_metric_schema(): + data = MetricSchema().load(METRIC_BODY) + assert data == METRIC + + +def test_param_schema(): + data = ParamSchema().load(PARAM_BODY) + assert data == PARAM + + +def test_tag_schema(): + data = TagSchema().load(TAG_BODY) + assert data == TAG + + +def test_tag_schema_dump(): + data = TagSchema().dump(TAG_BODY) + assert data == TAG_BODY + + +def test_experiment_run_data_schema(): + data = ExperimentRunDataSchema().dump( + {"metrics": [METRIC], "params": [PARAM], "tags": [TAG]} + ) + assert data == EXPERIMENT_RUN_DATA_BODY + + +def test_experiment_run_data_schema_empty(): + data = ExperimentRunDataSchema().dump({}) + assert data == {} + + +def test_experiment_run_data_schema_multiple(): + data = ExperimentRunDataSchema().dump({"tags": [TAG, OTHER_TAG]}) + assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} + + +PROJECT_ID_FILTER = ProjectIdFilter(ComparisonOperator.EQUAL_TO, PROJECT_ID) +PROJECT_ID_FILTER_BODY = { + "by": "projectId", + "operator": "eq", + "value": str(PROJECT_ID), +} + +TAG_FILTER = TagFilter("tag-key", ComparisonOperator.EQUAL_TO, "tag-value") +TAG_FILTER_BODY = { + "by": "tag", + "key": "tag-key", + "operator": "eq", + "value": "tag-value", +} + + +DEFINED_TEST_CASES = [ + (ComparisonOperator.DEFINED, False, "defined", False), + (ComparisonOperator.DEFINED, True, "defined", True), + (ComparisonOperator.DEFINED, 0, "defined", False), + (ComparisonOperator.DEFINED, 1, "defined", True), +] + + +def 
discrete_test_cases(value, expected): + return DEFINED_TEST_CASES + [ + (ComparisonOperator.EQUAL_TO, value, "eq", expected), + (ComparisonOperator.NOT_EQUAL_TO, value, "ne", expected), + ] + + +def continuous_test_cases(value, expected): + return discrete_test_cases(value, expected) + [ + (ComparisonOperator.GREATER_THAN, value, "gt", expected), + (ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value, "ge", expected), + (ComparisonOperator.LESS_THAN, value, "lt", expected), + (ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value, "le", expected), + ] + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(PROJECT_ID, str(PROJECT_ID)), +) +def test_filter_schema_project_id( + operator, value, expected_operator, expected_value +): + filter = ProjectIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "projectId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(EXPERIMENT_ID, EXPERIMENT_ID), +) +def test_filter_schema_experiment_id( + operator, value, expected_operator, expected_value +): + filter = ExperimentIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "experimentId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(RUN_ID, str(RUN_ID)), +) +def test_filter_schema_run_id( + operator, value, expected_operator, expected_value +): + filter = RunIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "runId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + continuous_test_cases(DELETED_AT, DELETED_AT_STRING_PYTHON), +) +def test_filter_schema_deleted_at( + 
operator, value, expected_operator, expected_value +): + filter = DeletedAtFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "deletedAt", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases("tag-value", "tag-value"), +) +def test_filter_schema_tag(operator, value, expected_operator, expected_value): + filter = TagFilter("tag-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "tag", + "key": "tag-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases("param-value", "param-value") + + continuous_test_cases(123.2, 123.2), +) +def test_filter_schema_param( + operator, value, expected_operator, expected_value +): + filter = ParamFilter("param-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "param", + "key": "param-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + continuous_test_cases(45.6, 45.6), +) +def test_filter_schema_metric( + operator, value, expected_operator, expected_value +): + filter = MetricFilter("metric-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "metric", + "key": "metric-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "filter_type", + [ProjectIdFilter, ExperimentIdFilter, RunIdFilter, DeletedAtFilter], +) +def test_filter_schema_invalid_value_no_key(filter_type): + filter = filter_type(ComparisonOperator.EQUAL_TO, "invalid") + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", [(ParamFilter, None), (MetricFilter, "invalid")] +) 
+def test_filter_schema_invalid_value_with_key(filter_type, value): + filter = filter_type("key", ComparisonOperator.EQUAL_TO, value) + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", + [ + (ProjectIdFilter, PROJECT_ID), + (ExperimentIdFilter, EXPERIMENT_ID), + (RunIdFilter, RUN_ID), + ], +) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_no_key(filter_type, value, operator): + filter = filter_type(operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", + [(TagFilter, "tag-value"), (ParamFilter, "param-string-value")], +) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_with_key(filter_type, value, operator): + filter = filter_type("key", operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "operator, expected_operator", + [(LogicalOperator.AND, "and"), (LogicalOperator.OR, "or")], +) +def test_filter_schema_compound(operator, expected_operator): + filter = CompoundFilter(operator, [PROJECT_ID_FILTER, TAG_FILTER]) + data = FilterSchema().dump(filter) + assert data == { + "operator": expected_operator, + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + } + + +def test_filter_schema_nested(): + filter = CompoundFilter( + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.AND, [PROJECT_ID_FILTER, TAG_FILTER] + ), + CompoundFilter( + LogicalOperator.OR, 
[TAG_FILTER, PROJECT_ID_FILTER] + ), + ], + ) + data = FilterSchema().dump(filter) + assert data == { + "operator": "and", + "conditions": [ + { + "operator": "and", + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + }, + { + "operator": "or", + "conditions": [TAG_FILTER_BODY, PROJECT_ID_FILTER_BODY], + }, + ], + } + + +@pytest.mark.parametrize( + "sort_type, by", + [ + (StartedAtSort, "startedAt"), + (RunNumberSort, "runNumber"), + (DurationSort, "duration"), + ], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_no_key(sort_type, by, order, expected_order): + sort = sort_type(order) + data = SortSchema().dump(sort) + assert data == {"by": by, "order": expected_order} + + +@pytest.mark.parametrize( + "sort_type, by", + [(TagSort, "tag"), (ParamSort, "param"), (MetricSort, "metric")], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_with_key(sort_type, by, order, expected_order): + sort = sort_type("sort-key", order) + data = SortSchema().dump(sort) + assert data == {"by": by, "key": "sort-key", "order": expected_order} + + +def test_run_query_schema(mocker): + mocker.patch.object(FilterSchema, "dump") + mocker.patch.object(SortSchema, "dump") + mocker.patch.object(PageSchema, "dump") + + filter = mocker.Mock() + sorts = [mocker.Mock(), mocker.Mock()] + page = mocker.Mock() + + run_query = RunQuery(filter, sorts, page) + data = RunQuerySchema().dump(run_query) + + assert data == { + "filter": FilterSchema.dump.return_value, + "sort": [SortSchema.dump.return_value, SortSchema.dump.return_value], + "page": PageSchema.dump.return_value, + } + + +def test_run_query_schema_defaults(): + run_query = RunQuery(None, None, None) + data = RunQuerySchema().dump(run_query) + assert data == {"filter": None, "sort": None, "page": None} + + +def test_list_runs_schema(mocker): + data = 
ListExperimentRunsResponseSchema().load( + LIST_EXPERIMENT_RUNS_RESPONSE_BODY + ) + assert data == LIST_EXPERIMENT_RUNS_RESPONSE + + +def test_page_schema_load(): + data = PageSchema().load(PAGE_BODY) + assert data == PAGE + + +def test_page_schema_dump(): + data = PageSchema().dump(PAGE) + assert data == PAGE_BODY + + +def test_pagination_schema(): + data = PaginationSchema().load(PAGINATION_BODY) + assert data == PAGINATION + + +@pytest.mark.parametrize("field", ["previous", "next"]) +def test_pagination_schema_nullable_field(field): + body = PAGINATION_BODY.copy() + del body[field] + data = PaginationSchema().load(body) + assert getattr(data, field) is None + + +def test_delete_experiment_runs_response_schema(mocker): + data = DeleteExperimentRunsResponseSchema().load( + DELETE_EXPERIMENT_RUNS_RESPONSE_BODY + ) + assert data == DELETE_EXPERIMENT_RUNS_RESPONSE + + +def test_delete_experiment_runs_response_schema_invalid(mocker): + with pytest.raises(ValidationError): + DeleteExperimentRunsResponseSchema().load({}) + + +def test_restore_experiment_runs_response_schema(mocker): + data = RestoreExperimentRunsResponseSchema().load( + RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY + ) + assert data == RESTORE_EXPERIMENT_RUNS_RESPONSE + + +def test_restore_experiment_runs_response_schema_invalid(mocker): + with pytest.raises(ValidationError): + RestoreExperimentRunsResponseSchema().load({}) + + +def test_metric_history_schema(): + data = MetricHistorySchema().load(METRIC_HISTORY_BODY) + assert data == METRIC_HISTORY + + +def test_metric_history_schema_invalid(): + with pytest.raises(ValidationError): + MetricHistorySchema().load({}) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py deleted file mode 100644 index 2441829b..00000000 --- a/tests/clients/test_experiment.py +++ /dev/null @@ -1,1000 +0,0 @@ -# Copyright 2018-2019 Faculty Science Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file 
except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from datetime import datetime -from uuid import uuid4 - -import pytest -from marshmallow import ValidationError -from pytz import UTC - -from faculty.clients.base import Conflict -from faculty.clients.experiment import ( - CreateRunSchema, - DeleteExperimentRunsResponse, - DeleteExperimentRunsResponseSchema, - Experiment, - ExperimentClient, - ExperimentNameConflict, - ExperimentDeleted, - ExperimentRun, - ExperimentRunDataSchema, - ExperimentRunSchema, - ExperimentRunStatus, - ExperimentSchema, - LifecycleStage, - ListExperimentRunsResponse, - ListExperimentRunsResponseSchema, - Metric, - MetricDataPoint, - MetricSchema, - MetricHistory, - MetricHistorySchema, - Page, - PageSchema, - Pagination, - PaginationSchema, - Param, - ParamConflict, - ParamSchema, - RestoreExperimentRunsResponse, - RestoreExperimentRunsResponseSchema, - Tag, - TagSchema, -) - -PROJECT_ID = uuid4() -EXPERIMENT_ID = 661 -EXPERIMENT_RUN_ID = uuid4() -EXPERIMENT_RUN_NUMBER = 3 -EXPERIMENT_RUN_NAME = "run name" -PARENT_RUN_ID = uuid4() -CREATED_AT = datetime(2018, 3, 10, 11, 32, 6, 247000, tzinfo=UTC) -CREATED_AT_STRING = "2018-03-10T11:32:06.247Z" -LAST_UPDATED_AT = datetime(2018, 3, 10, 11, 32, 30, 172000, tzinfo=UTC) -LAST_UPDATED_AT_STRING = "2018-03-10T11:32:30.172Z" -DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC) -DELETED_AT_STRING = "2018-03-10T11:37:42.482Z" - -EXPERIMENT = Experiment( - id=EXPERIMENT_ID, - name="experiment name", - description="experiment description", - artifact_location="https://example.com", - 
created_at=CREATED_AT, - last_updated_at=LAST_UPDATED_AT, - deleted_at=DELETED_AT, -) -EXPERIMENT_BODY = { - "experimentId": EXPERIMENT_ID, - "name": EXPERIMENT.name, - "description": EXPERIMENT.description, - "artifactLocation": EXPERIMENT.artifact_location, - "createdAt": CREATED_AT_STRING, - "lastUpdatedAt": LAST_UPDATED_AT_STRING, - "deletedAt": DELETED_AT_STRING, -} - -RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC) -RUN_STARTED_AT_NO_TIMEZONE = datetime(2018, 3, 10, 11, 39, 12, 110000) -RUN_STARTED_AT_STRING_PYTHON = "2018-03-10T11:39:12.110000+00:00" -RUN_STARTED_AT_STRING_JAVA = "2018-03-10T11:39:12.11Z" -RUN_ENDED_AT = datetime(2018, 3, 10, 11, 39, 15, 110000, tzinfo=UTC) -RUN_ENDED_AT_STRING = "2018-03-10T11:39:15.11Z" - -TAG = Tag(key="tag-key", value="tag-value") -TAG_BODY = {"key": "tag-key", "value": "tag-value"} - -OTHER_TAG = Tag(key="other-tag-key", value="other-tag-value") -OTHER_TAG_BODY = {"key": "other-tag-key", "value": "other-tag-value"} - -PARAM = Param(key="param-key", value="param-value") -PARAM_BODY = {"key": "param-key", "value": "param-value"} - -METRIC_KEY = "metric-key" -METRIC = Metric( - key=METRIC_KEY, - value=123.0, - timestamp=datetime(2018, 3, 12, 16, 20, 22, 122000, tzinfo=UTC), - step=0, -) -METRIC_BODY = { - "key": METRIC.key, - "value": METRIC.value, - "timestamp": "2018-03-12T16:20:22.122000+00:00", - "step": METRIC.step, -} - -METRIC_DATA_POINT = MetricDataPoint( - value=METRIC.value, timestamp=METRIC.timestamp, step=METRIC.step -) -METRIC_DATA_POINT_BODY = { - "value": METRIC_BODY["value"], - "timestamp": METRIC_BODY["timestamp"], - "step": METRIC_BODY["step"], -} - -METRIC_HISTORY = MetricHistory( - original_size=1, - subsampled=False, - key=METRIC_KEY, - history=[METRIC_DATA_POINT], -) -METRIC_HISTORY_BODY = { - "originalSize": METRIC_HISTORY.original_size, - "subsampled": METRIC_HISTORY.subsampled, - "key": METRIC_HISTORY.key, - "history": [METRIC_DATA_POINT_BODY], -} - -EXPERIMENT_RUN = 
ExperimentRun( - id=EXPERIMENT_RUN_ID, - run_number=EXPERIMENT_RUN_NUMBER, - name=EXPERIMENT_RUN_NAME, - parent_run_id=PARENT_RUN_ID, - experiment_id=EXPERIMENT.id, - artifact_location="faculty:", - status=ExperimentRunStatus.RUNNING, - started_at=RUN_STARTED_AT, - ended_at=RUN_ENDED_AT, - deleted_at=DELETED_AT, - tags=[TAG], - params=[PARAM], - metrics=[METRIC], -) -EXPERIMENT_RUN_BODY = { - "experimentId": EXPERIMENT.id, - "runId": str(EXPERIMENT_RUN_ID), - "runNumber": EXPERIMENT_RUN_NUMBER, - "name": EXPERIMENT_RUN_NAME, - "parentRunId": str(PARENT_RUN_ID), - "artifactLocation": "faculty:", - "status": "running", - "startedAt": RUN_STARTED_AT_STRING_JAVA, - "endedAt": RUN_ENDED_AT_STRING, - "deletedAt": DELETED_AT_STRING, - "tags": [TAG_BODY], - "metrics": [METRIC_BODY], - "params": [PARAM_BODY], -} - -EXPERIMENT_RUN_DATA_BODY = { - "metrics": [METRIC_BODY], - "params": [PARAM_BODY], - "tags": [TAG_BODY], -} - - -PAGE = Page(start=3, limit=10) -PAGE_BODY = {"start": PAGE.start, "limit": PAGE.limit} - -PAGINATION = Pagination( - start=20, - size=10, - previous=Page(start=10, limit=10), - next=Page(start=30, limit=10), -) -PAGINATION_BODY = { - "start": PAGINATION.start, - "size": PAGINATION.size, - "previous": { - "start": PAGINATION.previous.start, - "limit": PAGINATION.previous.limit, - }, - "next": {"start": PAGINATION.next.start, "limit": PAGINATION.next.limit}, -} - -LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse( - runs=[EXPERIMENT_RUN], pagination=PAGINATION -) -LIST_EXPERIMENT_RUNS_RESPONSE_BODY = { - "runs": [EXPERIMENT_RUN_BODY], - "pagination": PAGINATION_BODY, -} - -DELETE_EXPERIMENT_RUNS_RESPONSE = DeleteExperimentRunsResponse( - deleted_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] -) -DELETE_EXPERIMENT_RUNS_RESPONSE_BODY = { - "deletedRunIds": [ - str(run_id) - for run_id in DELETE_EXPERIMENT_RUNS_RESPONSE.deleted_run_ids - ], - "conflictedRunIds": [ - str(run_id) - for run_id in 
DELETE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids - ], -} - -RESTORE_EXPERIMENT_RUNS_RESPONSE = RestoreExperimentRunsResponse( - restored_run_ids=[uuid4(), uuid4()], conflicted_run_ids=[uuid4(), uuid4()] -) -RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY = { - "restoredRunIds": [ - str(run_id) - for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.restored_run_ids - ], - "conflictedRunIds": [ - str(run_id) - for run_id in RESTORE_EXPERIMENT_RUNS_RESPONSE.conflicted_run_ids - ], -} - - -def test_experiment_schema(): - data = ExperimentSchema().load(EXPERIMENT_BODY) - assert data == EXPERIMENT - - -def test_experiment_schema_nullable_deleted_at(): - body = EXPERIMENT_BODY.copy() - body["deletedAt"] = None - data = ExperimentSchema().load(body) - assert data.deleted_at is None - - -def test_experiment_schema_invalid(): - with pytest.raises(ValidationError): - ExperimentSchema().load({}) - - -def test_experiment_run_schema(): - data = ExperimentRunSchema().load(EXPERIMENT_RUN_BODY) - assert data == EXPERIMENT_RUN - - -@pytest.mark.parametrize( - "data_key, field", - [ - ("parentRunId", "parent_run_id"), - ("endedAt", "ended_at"), - ("deletedAt", "deleted_at"), - ], -) -def test_experiment_run_schema_nullable_field(data_key, field): - body = EXPERIMENT_RUN_BODY.copy() - del body[data_key] - data = ExperimentRunSchema().load(body) - assert getattr(data, field) is None - - -@pytest.mark.parametrize("parent_run_id", [None, PARENT_RUN_ID]) -@pytest.mark.parametrize( - "started_at", - [RUN_STARTED_AT, RUN_STARTED_AT_NO_TIMEZONE], - ids=["timezone", "no timezone"], -) -@pytest.mark.parametrize("artifact_location", [None, "faculty:project-id"]) -@pytest.mark.parametrize("tags", [[], [{"key": "key", "value": "value"}]]) -def test_create_run_schema(parent_run_id, started_at, artifact_location, tags): - data = CreateRunSchema().dump( - { - "name": EXPERIMENT_RUN_NAME, - "parent_run_id": parent_run_id, - "started_at": started_at, - "artifact_location": artifact_location, - "tags": tags, - } - ) 
- assert data == { - "name": EXPERIMENT_RUN_NAME, - "parentRunId": None if parent_run_id is None else str(parent_run_id), - "startedAt": RUN_STARTED_AT_STRING_PYTHON, - "artifactLocation": artifact_location, - "tags": tags, - } - - -def test_metric_schema(): - data = MetricSchema().load(METRIC_BODY) - assert data == METRIC - - -def test_param_schema(): - data = ParamSchema().load(PARAM_BODY) - assert data == PARAM - - -def test_tag_schema(): - data = TagSchema().load(TAG_BODY) - assert data == TAG - - -def test_tag_schema_dump(): - data = TagSchema().dump(TAG_BODY) - assert data == TAG_BODY - - -def test_experiment_run_data_schema(): - data = ExperimentRunDataSchema().dump( - {"metrics": [METRIC], "params": [PARAM], "tags": [TAG]} - ) - assert data == EXPERIMENT_RUN_DATA_BODY - - -def test_experiment_run_data_schema_empty(): - data = ExperimentRunDataSchema().dump({}) - assert data == {} - - -def test_experiment_run_data_schema_multiple(): - data = ExperimentRunDataSchema().dump({"tags": [TAG, OTHER_TAG]}) - assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} - - -@pytest.mark.parametrize("description", [None, "experiment description"]) -@pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) -def test_experiment_client_create(mocker, description, artifact_location): - mocker.patch.object(ExperimentClient, "_post", return_value=EXPERIMENT) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") - - client = ExperimentClient(mocker.Mock()) - returned_experiment = client.create( - PROJECT_ID, "experiment name", description, artifact_location - ) - assert returned_experiment == EXPERIMENT - - schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/experiment".format(PROJECT_ID), - schema_mock.return_value, - json={ - "name": "experiment name", - "description": description, - "artifactLocation": artifact_location, - }, - ) - - -def test_experiment_client_create_name_conflict(mocker): - 
error_code = "experiment_name_conflict" - exception = Conflict(mocker.Mock(), mocker.Mock(), error_code) - mocker.patch.object(ExperimentClient, "_post", side_effect=exception) - - client = ExperimentClient(mocker.Mock()) - with pytest.raises( - ExperimentNameConflict, match="name 'experiment name' already exists" - ): - client.create(PROJECT_ID, "experiment name") - - -def test_experiment_client_get(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=EXPERIMENT) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") - - client = ExperimentClient(mocker.Mock()) - returned_experiment = client.get(PROJECT_ID, EXPERIMENT.id) - assert returned_experiment == EXPERIMENT - - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT.id), - schema_mock.return_value, - ) - - -def test_experiment_client_list(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=[EXPERIMENT]) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") - - client = ExperimentClient(mocker.Mock()) - assert client.list(PROJECT_ID) == [EXPERIMENT] - - schema_mock.assert_called_once_with(many=True) - ExperimentClient._get.assert_called_once_with( - "/project/{}/experiment".format(PROJECT_ID), - schema_mock.return_value, - params={}, - ) - - -def test_experiment_client_list_lifecycle_filter(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=[EXPERIMENT]) - schema_mock = mocker.patch("faculty.clients.experiment.ExperimentSchema") - - client = ExperimentClient(mocker.Mock()) - returned_experiments = client.list( - PROJECT_ID, lifecycle_stage=LifecycleStage.ACTIVE - ) - assert returned_experiments == [EXPERIMENT] - - schema_mock.assert_called_once_with(many=True) - ExperimentClient._get.assert_called_once_with( - "/project/{}/experiment".format(PROJECT_ID), - schema_mock.return_value, - params={"lifecycleStage": "active"}, - ) 
- - -@pytest.mark.parametrize("name", [None, "new name"]) -@pytest.mark.parametrize("description", [None, "new description"]) -def test_experiment_client_update(mocker, name, description): - mocker.patch.object(ExperimentClient, "_patch_raw") - - client = ExperimentClient(mocker.Mock()) - client.update( - PROJECT_ID, EXPERIMENT_ID, name=name, description=description - ) - - ExperimentClient._patch_raw.assert_called_once_with( - "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT_ID), - json={"name": name, "description": description}, - ) - - -def test_experiment_client_update_name_conflict(mocker): - error_code = "experiment_name_conflict" - exception = Conflict(mocker.Mock(), mocker.Mock(), error_code) - mocker.patch.object(ExperimentClient, "_patch_raw", side_effect=exception) - - client = ExperimentClient(mocker.Mock()) - with pytest.raises( - ExperimentNameConflict, match="name 'new name' already exists" - ): - client.update(PROJECT_ID, EXPERIMENT_ID, name="new name") - - -def test_delete(mocker): - mocker.patch.object(ExperimentClient, "_delete_raw") - - client = ExperimentClient(mocker.Mock()) - client.delete(PROJECT_ID, EXPERIMENT_ID) - - ExperimentClient._delete_raw.assert_called_once_with( - "/project/{}/experiment/{}".format(PROJECT_ID, EXPERIMENT_ID) - ) - - -def test_restore(mocker): - mocker.patch.object(ExperimentClient, "_put_raw") - - client = ExperimentClient(mocker.Mock()) - client.restore(PROJECT_ID, EXPERIMENT_ID) - - ExperimentClient._put_raw.assert_called_once_with( - "/project/{}/experiment/{}/restore".format(PROJECT_ID, EXPERIMENT_ID) - ) - - -def test_experiment_create_run(mocker): - mocker.patch.object(ExperimentClient, "_post", return_value=EXPERIMENT_RUN) - request_schema_mock = mocker.patch( - "faculty.clients.experiment.CreateRunSchema" - ) - dump_mock = request_schema_mock.return_value.dump - response_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - started_at = mocker.Mock() - 
artifact_location = mocker.Mock() - - client = ExperimentClient(mocker.Mock()) - returned_run = client.create_run( - PROJECT_ID, - EXPERIMENT_ID, - EXPERIMENT_RUN_NAME, - started_at, - PARENT_RUN_ID, - artifact_location=artifact_location, - ) - assert returned_run == EXPERIMENT_RUN - - request_schema_mock.assert_called_once_with() - dump_mock.assert_called_once_with( - { - "name": EXPERIMENT_RUN_NAME, - "parent_run_id": PARENT_RUN_ID, - "started_at": started_at, - "artifact_location": artifact_location, - "tags": [], - } - ) - response_schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/experiment/{}/run".format(PROJECT_ID, EXPERIMENT_ID), - response_schema_mock.return_value, - json=dump_mock.return_value, - ) - - -def test_experiment_create_run_experiment_deleted_conflict(mocker): - message = "experiment deleted" - error_code = "experiment_deleted" - response_mock = mocker.Mock() - response_mock.json.return_value = {"experimentId": 42} - exception = Conflict(response_mock, message, error_code) - - mocker.patch.object(ExperimentClient, "_post", side_effect=exception) - started_at = mocker.Mock() - artifact_location = mocker.Mock() - - client = ExperimentClient(mocker.Mock()) - with pytest.raises(ExperimentDeleted, match=message): - client.create_run( - PROJECT_ID, - EXPERIMENT_ID, - EXPERIMENT_RUN_NAME, - started_at, - PARENT_RUN_ID, - artifact_location=artifact_location, - ) - - -def test_experiment_client_get_run(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=EXPERIMENT_RUN) - schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - - client = ExperimentClient(mocker.Mock()) - returned_run = client.get_run(PROJECT_ID, EXPERIMENT_RUN_ID) - assert returned_run == EXPERIMENT_RUN - - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run/{}".format(PROJECT_ID, EXPERIMENT_RUN_ID), - schema_mock.return_value, - ) - - 
-def test_list_runs_schema(mocker): - data = ListExperimentRunsResponseSchema().load( - LIST_EXPERIMENT_RUNS_RESPONSE_BODY - ) - assert data == LIST_EXPERIMENT_RUNS_RESPONSE - - -def test_page_schema(): - data = PageSchema().load(PAGE_BODY) - assert data == PAGE - - -def test_pagination_schema(): - data = PaginationSchema().load(PAGINATION_BODY) - assert data == PAGINATION - - -@pytest.mark.parametrize("field", ["previous", "next"]) -def test_pagination_schema_nullable_field(field): - body = PAGINATION_BODY.copy() - del body[field] - data = PaginationSchema().load(body) - assert getattr(data, field) is None - - -def test_delete_experiment_runs_response_schema(mocker): - data = DeleteExperimentRunsResponseSchema().load( - DELETE_EXPERIMENT_RUNS_RESPONSE_BODY - ) - assert data == DELETE_EXPERIMENT_RUNS_RESPONSE - - -def test_delete_experiment_runs_response_schema_invalid(mocker): - with pytest.raises(ValidationError): - DeleteExperimentRunsResponseSchema().load({}) - - -def test_restore_experiment_runs_response_schema(mocker): - data = RestoreExperimentRunsResponseSchema().load( - RESTORE_EXPERIMENT_RUNS_RESPONSE_BODY - ) - assert data == RESTORE_EXPERIMENT_RUNS_RESPONSE - - -def test_restore_experiment_runs_response_schema_invalid(mocker): - with pytest.raises(ValidationError): - RestoreExperimentRunsResponseSchema().load({}) - - -def test_experiment_client_list_runs_all(mocker): - mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) - - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[], - ) - - -def test_experiment_client_list_runs_experiments_filter(mocker): - 
mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) - - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[123, 456]) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("experimentId", 123), ("experimentId", 456)], - ) - - -def test_experiment_client_list_runs_experiments_filter_empty(mocker): - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) - - assert list_result == ListExperimentRunsResponse( - runs=[], - pagination=Pagination(start=0, size=0, previous=None, next=None), - ) - - -def test_experiment_client_list_runs_page(mocker): - mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) - - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, start=20, limit=10) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("start", 20), ("limit", 10)], - ) - - -def test_log_run_data(mocker): - mocker.patch.object(ExperimentClient, "_patch_raw") - run_data_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunDataSchema" - ) - run_data_dump_mock = run_data_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - client.log_run_data( - PROJECT_ID, - EXPERIMENT_RUN_ID, - metrics=[METRIC], - params=[PARAM], - tags=[TAG], - ) - - run_data_schema_mock.assert_called_once_with() - 
run_data_dump_mock.assert_called_once_with( - {"metrics": [METRIC], "params": [PARAM], "tags": [TAG]} - ) - ExperimentClient._patch_raw.assert_called_once_with( - "/project/{}/run/{}/data".format(PROJECT_ID, EXPERIMENT_RUN_ID), - json=run_data_dump_mock.return_value, - ) - - -def test_log_run_data_param_conflict(mocker): - message = "bad params" - error_code = "conflicting_params" - response_mock = mocker.Mock() - response_mock.json.return_value = {"parameterKeys": ["bad-key"]} - exception = Conflict(response_mock, message, error_code) - - mocker.patch.object(ExperimentClient, "_patch_raw", side_effect=exception) - - client = ExperimentClient(mocker.Mock()) - - with pytest.raises(ParamConflict, match=message): - client.log_run_data(PROJECT_ID, EXPERIMENT_RUN_ID, params=[PARAM]) - - -def test_log_run_data_other_conflict(mocker): - response_mock = mocker.Mock() - exception = Conflict(response_mock, "", "") - - mocker.patch.object(ExperimentClient, "_patch_raw", side_effect=exception) - client = ExperimentClient(mocker.Mock()) - - with pytest.raises(Conflict): - client.log_run_data(PROJECT_ID, EXPERIMENT_RUN_ID, params=[PARAM]) - - -def test_log_run_data_empty(mocker): - mocker.patch.object(ExperimentClient, "_patch_raw") - - client = ExperimentClient(mocker.Mock()) - - client.log_run_data(PROJECT_ID, EXPERIMENT_RUN_ID) - ExperimentClient._patch_raw.assert_not_called() - - -def test_update_run_info(mocker): - mocker.patch.object( - ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) - run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" - ) - run_info_dump_mock = run_info_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - returned_run = client.update_run_info( - PROJECT_ID, - EXPERIMENT_RUN_ID, - EXPERIMENT_RUN.status, - EXPERIMENT_RUN.ended_at, - ) - assert returned_run == EXPERIMENT_RUN - - 
run_schema_mock.assert_called_once_with() - run_info_schema_mock.assert_called_once_with() - run_info_dump_mock.assert_called_once_with( - {"status": EXPERIMENT_RUN.status, "ended_at": EXPERIMENT_RUN.ended_at} - ) - ExperimentClient._patch.assert_called_once_with( - "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), - run_schema_mock.return_value, - json=run_info_dump_mock.return_value, - ) - - -def test_update_run_info_status_only(mocker): - mocker.patch.object( - ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) - run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" - ) - run_info_dump_mock = run_info_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - returned_run = client.update_run_info( - PROJECT_ID, EXPERIMENT_RUN_ID, status=EXPERIMENT_RUN.status - ) - assert returned_run == EXPERIMENT_RUN - - run_schema_mock.assert_called_once_with() - run_info_schema_mock.assert_called_once_with() - run_info_dump_mock.assert_called_once_with( - {"status": EXPERIMENT_RUN.status, "ended_at": None} - ) - ExperimentClient._patch.assert_called_once_with( - "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), - run_schema_mock.return_value, - json=run_info_dump_mock.return_value, - ) - - -def test_update_run_info_ended_at_only(mocker): - mocker.patch.object( - ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) - run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" - ) - run_info_dump_mock = run_info_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - returned_run = client.update_run_info( - PROJECT_ID, EXPERIMENT_RUN_ID, ended_at=EXPERIMENT_RUN.ended_at - ) - assert returned_run == EXPERIMENT_RUN - - 
run_schema_mock.assert_called_once_with() - run_info_schema_mock.assert_called_once_with() - run_info_dump_mock.assert_called_once_with( - {"status": None, "ended_at": EXPERIMENT_RUN.ended_at} - ) - ExperimentClient._patch.assert_called_once_with( - "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), - run_schema_mock.return_value, - json=run_info_dump_mock.return_value, - ) - - -def test_update_run_info_empty(mocker): - mocker.patch.object( - ExperimentClient, "_patch", return_value=EXPERIMENT_RUN - ) - run_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunSchema" - ) - run_info_schema_mock = mocker.patch( - "faculty.clients.experiment.ExperimentRunInfoSchema" - ) - run_info_dump_mock = run_info_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - returned_run = client.update_run_info(PROJECT_ID, EXPERIMENT_RUN_ID) - assert returned_run == EXPERIMENT_RUN - - run_schema_mock.assert_called_once_with() - run_info_schema_mock.assert_called_once_with() - run_info_dump_mock.assert_called_once_with( - {"status": None, "ended_at": None} - ) - ExperimentClient._patch.assert_called_once_with( - "/project/{}/run/{}/info".format(PROJECT_ID, EXPERIMENT_RUN_ID), - run_schema_mock.return_value, - json=run_info_dump_mock.return_value, - ) - - -def test_metric_history_schema(): - data = MetricHistorySchema().load(METRIC_HISTORY_BODY) - assert data == METRIC_HISTORY - - -def test_metric_history_schema_invalid(): - with pytest.raises(ValidationError): - MetricHistorySchema().load({}) - - -def test_get_metric_history(mocker): - mocker.patch.object(ExperimentClient, "_get", return_value=METRIC_HISTORY) - metric_history_schema_mock = mocker.patch( - "faculty.clients.experiment.MetricHistorySchema" - ) - - client = ExperimentClient(mocker.Mock()) - - returned_metric_history = client.get_metric_history( - PROJECT_ID, EXPERIMENT_RUN_ID, METRIC_KEY - ) - assert returned_metric_history == [METRIC] - - 
metric_history_schema_mock.assert_called_once_with() - - ExperimentClient._get.assert_called_once_with( - "/project/{}/run/{}/metric/{}/history".format( - PROJECT_ID, EXPERIMENT_RUN_ID, METRIC_KEY - ), - metric_history_schema_mock.return_value, - ) - - -def test_delete_runs(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=DELETE_EXPERIMENT_RUNS_RESPONSE - ) - schema_mock = mocker.patch( - "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" - ) - run_ids = [uuid4(), uuid4()] - - client = ExperimentClient(mocker.Mock()) - assert ( - client.delete_runs(PROJECT_ID, run_ids) - == DELETE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, - ], - } - } - - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/delete/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, - ) - - -def test_delete_runs_no_run_ids(mocker): - mocker.patch.object(ExperimentClient, "_post") - schema_mock = mocker.patch( - "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" - ) - - client = ExperimentClient(mocker.Mock()) - client.delete_runs(PROJECT_ID) - - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/delete/query".format(PROJECT_ID), - schema_mock.return_value, - json={}, - ) - - -def test_delete_runs_empty_list(mocker): - client = ExperimentClient(mocker.Mock()) - - assert client.delete_runs( - PROJECT_ID, run_ids=[] - ) == DeleteExperimentRunsResponse( - deleted_run_ids=[], conflicted_run_ids=[] - ) - - -def test_restore_runs(mocker): - mocker.patch.object( - ExperimentClient, - "_post", - return_value=RESTORE_EXPERIMENT_RUNS_RESPONSE, - ) - schema_mock = mocker.patch( - "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" - ) - run_ids = [uuid4(), uuid4()] - - client = 
ExperimentClient(mocker.Mock()) - assert ( - client.restore_runs(PROJECT_ID, run_ids) - == RESTORE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, - ], - } - } - - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/restore/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, - ) - - -def test_restore_runs_no_run_ids(mocker): - mocker.patch.object(ExperimentClient, "_post") - schema_mock = mocker.patch( - "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" - ) - - client = ExperimentClient(mocker.Mock()) - client.restore_runs(PROJECT_ID) - - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/restore/query".format(PROJECT_ID), - schema_mock.return_value, - json={}, - ) - - -def test_restore_runs_empty_list(mocker): - client = ExperimentClient(mocker.Mock()) - - assert client.restore_runs( - PROJECT_ID, run_ids=[] - ) == RestoreExperimentRunsResponse( - restored_run_ids=[], conflicted_run_ids=[] - )