diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f95946de4..c1406cbdd 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,14 @@ History All release highlights of this project will be documented in this file. +4.4.36 - June 05, 2025 +______________________ + +**Updated** + + - ``SAClient.get_project_steps`` and ``SAClient.get_project_steps`` now support keypoint workflows, enabling structured step configuration with class IDs, attributes, and step connections. + - ``SAClient.list_users`` now returns user-specific permission states for paused, allow_orchestrate, allow_run_explore, and allow_view_sdk_token. + 4.4.35 - May 2, 2025 ____________________ diff --git a/docs/source/api_reference/api_project.rst b/docs/source/api_reference/api_project.rst index 52be1c3b6..858030333 100644 --- a/docs/source/api_reference/api_project.rst +++ b/docs/source/api_reference/api_project.rst @@ -24,6 +24,6 @@ Projects .. automethod:: superannotate.SAClient.add_contributors_to_project .. automethod:: superannotate.SAClient.get_project_settings .. automethod:: superannotate.SAClient.set_project_default_image_quality_in_editor -.. automethod:: superannotate.SAClient.set_project_steps .. automethod:: superannotate.SAClient.get_project_steps +.. automethod:: superannotate.SAClient.set_project_steps .. automethod:: superannotate.SAClient.get_component_config diff --git a/src/superannotate/__init__.py b/src/superannotate/__init__.py index bb86da3c8..6beddbe4f 100644 --- a/src/superannotate/__init__.py +++ b/src/superannotate/__init__.py @@ -3,7 +3,7 @@ import sys -__version__ = "4.4.35" +__version__ = "4.4.36" os.environ.update({"sa_version": __version__}) diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index e7d028884..8bcc3d924 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -74,7 +74,6 @@ from lib.core.entities.work_managament import WMUserTypeEnum from lib.core.jsx_conditions import EmptyQuery - logger = logging.getLogger("sa") NotEmptyStr = constr(strict=True, min_length=1) @@ -1493,10 +1492,11 @@ def get_project_steps(self, project: Union[str, dict]): :param project: project name or metadata :type project: str or dict - :return: project steps - :rtype: list of dicts + :return: A list of step dictionaries, + or a dictionary containing both steps and their connections (for Keypoint workflows). + :rtype: list of dicts or dict - Response Example: + Response Example for General Annotation Project: :: [ @@ -1515,6 +1515,34 @@ def get_project_steps(self, project: Union[str, dict]): } ] + Response Example for Keypoint Annotation Project: + :: + + { + "steps": [ + { + "step": 1, + "className": "Left Shoulder", + "class_id": "1", + "attribute": [ + { + "attribute": { + "id": 123, + "group_id": 12 + } + } + ] + }, + { + "step": 2, + "class_id": "2", + "className": "Right Shoulder", + } + ], + "connections": [ + [1, 2] + ] + } """ project_name, _ = extract_project_folder(project) project = self.controller.get_project(project_name) @@ -2511,7 +2539,12 @@ def download_export( if response.errors: raise AppException(response.errors) - def set_project_steps(self, project: Union[NotEmptyStr, dict], steps: List[dict]): + def set_project_steps( + self, + project: Union[NotEmptyStr, dict], + steps: List[dict], + connections: List[List[int]] = None, + ): """Sets project's steps. :param project: project name or metadata @@ -2520,7 +2553,11 @@ def set_project_steps(self, project: Union[NotEmptyStr, dict], steps: List[dict] :param steps: new workflow list of dicts :type steps: list of dicts - Request Example: + :param connections: Defines connections between keypoint annotation steps. + Each inner list specifies a pair of step IDs indicating a connection. + :type connections: list of list + + Request Example for General Annotation Project: :: sa.set_project_steps( @@ -2541,10 +2578,40 @@ def set_project_steps(self, project: Union[NotEmptyStr, dict], steps: List[dict] } ] ) + + Request Example for Keypoint Annotation Project: + :: + + sa.set_project_steps( + project="Pose Estimation Project", + steps=[ + { + "step": 1, + "class_id": 12, + "attribute": [ + { + "attribute": { + "id": 123, + "group_id": 12 + } + } + ] + }, + { + "step": 2, + "class_id": 13 + } + ], + connections=[ + [1, 2] + ] + ) """ project_name, _ = extract_project_folder(project) project = self.controller.get_project(project_name) - response = self.controller.projects.set_steps(project, steps=steps) + response = self.controller.projects.set_steps( + project, steps=steps, connections=connections + ) if response.errors: raise AppException(response.errors) diff --git a/src/superannotate/lib/core/__init__.py b/src/superannotate/lib/core/__init__.py index 8203ca390..a27bf5ad4 100644 --- a/src/superannotate/lib/core/__init__.py +++ b/src/superannotate/lib/core/__init__.py @@ -10,6 +10,7 @@ from lib.core.enums import ImageQuality from lib.core.enums import ProjectStatus from lib.core.enums import ProjectType +from lib.core.enums import StepsType from lib.core.enums import TrainingStatus from lib.core.enums import UploadState from lib.core.enums import UserRole @@ -186,6 +187,7 @@ def setup_logging(level=DEFAULT_LOGGING_LEVEL, file_path=LOG_FILE_LOCATION): FolderStatus, ProjectStatus, ProjectType, + StepsType, UserRole, UploadState, TrainingStatus, diff --git a/src/superannotate/lib/core/entities/work_managament.py b/src/superannotate/lib/core/entities/work_managament.py index 40aff7cb5..8e23f38b9 100644 --- a/src/superannotate/lib/core/entities/work_managament.py +++ b/src/superannotate/lib/core/entities/work_managament.py @@ -133,6 +133,7 @@ class WMProjectUserEntity(TimedBaseModel): email: Optional[str] state: Optional[WMUserStateEnum] custom_fields: Optional[dict] = Field(dict(), alias="customField") + permissions: Optional[dict] class Config: extra = Extra.ignore diff --git a/src/superannotate/lib/core/enums.py b/src/superannotate/lib/core/enums.py index cb2631bb6..f780d877a 100644 --- a/src/superannotate/lib/core/enums.py +++ b/src/superannotate/lib/core/enums.py @@ -117,6 +117,12 @@ def images(self): return self.VECTOR.value, self.PIXEL.value, self.TILED.value +class StepsType(Enum): + INITIAL = 1 + BASIC = 2 + KEYPOINT = 3 + + class UserRole(BaseTitledEnum): CONTRIBUTOR = "Contributor", 4 ADMIN = "Admin", 7 diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py index 1db527fe7..9ffba499f 100644 --- a/src/superannotate/lib/core/serviceproviders.py +++ b/src/superannotate/lib/core/serviceproviders.py @@ -264,10 +264,18 @@ def set_settings( def list_steps(self, project: entities.ProjectEntity): raise NotImplementedError + @abstractmethod + def list_keypoint_steps(self, project: entities.ProjectEntity): + raise NotImplementedError + @abstractmethod def set_step(self, project: entities.ProjectEntity, step: entities.StepEntity): raise NotImplementedError + @abstractmethod + def set_keypoint_steps(self, project: entities.ProjectEntity, steps, connections): + raise NotImplementedError + @abstractmethod def set_steps(self, project: entities.ProjectEntity, steps: list): raise NotImplementedError diff --git a/src/superannotate/lib/core/usecases/annotations.py b/src/superannotate/lib/core/usecases/annotations.py index c6628f152..854f1ebd0 100644 --- a/src/superannotate/lib/core/usecases/annotations.py +++ b/src/superannotate/lib/core/usecases/annotations.py @@ -107,10 +107,10 @@ def log_report( class ItemToUpload(BaseModel): item: BaseItemEntity - annotation_json: Optional[dict] - path: Optional[str] - file_size: Optional[int] - mask: Optional[io.BytesIO] + annotation_json: Optional[dict] = None + path: Optional[str] = None + file_size: Optional[int] = None + mask: Optional[io.BytesIO] = None class Config: arbitrary_types_allowed = True @@ -282,7 +282,7 @@ def validate_project_type(self): raise AppException("Unsupported project type.") def _validate_json(self, json_data: dict) -> list: - if self._project.type >= constants.ProjectType.PIXEL.value: + if self._project.type >= int(constants.ProjectType.PIXEL): return [] use_case = ValidateAnnotationUseCase( reporter=self.reporter, @@ -2101,16 +2101,16 @@ def execute(self): if categorization_enabled: item_id_category_map = {} for item_name in uploaded_annotations: - category = ( - name_annotation_map[item_name]["metadata"] - .get("item_category", {}) - .get("value") + category = name_annotation_map[item_name]["metadata"].get( + "item_category", None ) if category: item_id_category_map[name_item_map[item_name].id] = category - self._attach_categories( - folder_id=folder.id, item_id_category_map=item_id_category_map - ) + if item_id_category_map: + self._attach_categories( + folder_id=folder.id, + item_id_category_map=item_id_category_map, + ) workflow = self._service_provider.work_management.get_workflow( self._project.workflow_id ) @@ -2149,7 +2149,7 @@ def _attach_categories(self, folder_id: int, item_id_category_map: Dict[int, str ) response.raise_for_status() categories = response.data - self._category_name_to_id_map = {c.name: c.id for c in categories} + self._category_name_to_id_map = {c.value: c.id for c in categories} for item_id in list(item_id_category_map.keys()): category_name = item_id_category_map[item_id] if category_name not in self._category_name_to_id_map: diff --git a/src/superannotate/lib/core/usecases/projects.py b/src/superannotate/lib/core/usecases/projects.py index 23a913ced..3877c5e56 100644 --- a/src/superannotate/lib/core/usecases/projects.py +++ b/src/superannotate/lib/core/usecases/projects.py @@ -1,9 +1,10 @@ import decimal import logging +import math from collections import defaultdict from typing import List -import lib.core as constances +import lib.core as constants from lib.core.conditions import Condition from lib.core.conditions import CONDITION_EQ as EQ from lib.core.entities import ContributorEntity @@ -21,7 +22,6 @@ from lib.core.usecases.base import BaseUseCase from lib.core.usecases.base import BaseUserBasedUseCase - logger = logging.getLogger("sa") @@ -228,12 +228,12 @@ def __init__( def validate_settings(self): for setting in self._project.settings[:]: - if setting.attribute not in constances.PROJECT_SETTINGS_VALID_ATTRIBUTES: + if setting.attribute not in constants.PROJECT_SETTINGS_VALID_ATTRIBUTES: self._project.settings.remove(setting) if setting.attribute == "ImageQuality" and isinstance(setting.value, str): - setting.value = constances.ImageQuality(setting.value).value + setting.value = constants.ImageQuality(setting.value).value elif setting.attribute == "FrameRate": - if not self._project.type == constances.ProjectType.VIDEO.value: + if not self._project.type == constants.ProjectType.VIDEO.value: raise AppValidationException( "FrameRate is available only for Video projects" ) @@ -263,14 +263,14 @@ def validate_project_name(self): if ( len( set(self._project.name).intersection( - constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES + constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES ) ) > 0 ): self._project.name = "".join( "_" - if char in constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES + if char in constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES else char for char in self._project.name ) @@ -291,7 +291,7 @@ def validate_project_name(self): def execute(self): if self.is_valid(): # new projects can only have the status of NotStarted - self._project.status = constances.ProjectStatus.NotStarted.value + self._project.status = constants.ProjectStatus.NotStarted.value response = self._service_provider.projects.create(self._project) if not response.ok: self._response.errors = response.error @@ -326,7 +326,7 @@ def execute(self): data["classes"] = self._project.classes logger.info( f"Created project {entity.name} (ID {entity.id}) " - f"with type {constances.ProjectType(self._response.data.type).name}." + f"with type {constants.ProjectType(self._response.data.type).name}." ) return self._response @@ -368,12 +368,12 @@ def __init__( def validate_settings(self): for setting in self._project.settings[:]: - if setting.attribute not in constances.PROJECT_SETTINGS_VALID_ATTRIBUTES: + if setting.attribute not in constants.PROJECT_SETTINGS_VALID_ATTRIBUTES: self._project.settings.remove(setting) if setting.attribute == "ImageQuality" and isinstance(setting.value, str): - setting.value = constances.ImageQuality(setting.value).value + setting.value = constants.ImageQuality(setting.value).value elif setting.attribute == "FrameRate": - if not self._project.type == constances.ProjectType.VIDEO.value: + if not self._project.type == constants.ProjectType.VIDEO.value: raise AppValidationException( "FrameRate is available only for Video projects" ) @@ -404,14 +404,14 @@ def validate_project_name(self): if ( len( set(self._project.name).intersection( - constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES + constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES ) ) > 0 ): self._project.name = "".join( "_" - if char in constances.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES + if char in constants.SPECIAL_CHARACTERS_IN_PROJECT_FOLDER_NAMES else char for char in self._project.name ) @@ -484,26 +484,39 @@ def __init__(self, project: ProjectEntity, service_provider: BaseServiceProvider self._service_provider = service_provider def validate_project_type(self): - if self._project.type in constances.LIMITED_FUNCTIONS: + if self._project.type in constants.LIMITED_FUNCTIONS: raise AppValidationException( - constances.LIMITED_FUNCTIONS[self._project.type] + constants.LIMITED_FUNCTIONS[self._project.type] ) def execute(self): if self.is_valid(): - data = [] - steps = self._service_provider.projects.list_steps(self._project).data - for step in steps: - step_data = step.dict() - annotation_classes = self._service_provider.annotation_classes.list( - Condition("project_id", self._project.id, EQ) - ).data - for annotation_class in annotation_classes: - if annotation_class.id == step.class_id: - step_data["className"] = annotation_class.name - break - data.append(step_data) - self._response.data = data + project_settings = self._service_provider.projects.list_settings( + project=self._project + ).data + step_setting = next( + (i for i in project_settings if i.attribute == "WorkflowType"), None + ) + if step_setting.value == constants.StepsType.KEYPOINT.value: + self._response.data = ( + self._service_provider.projects.list_keypoint_steps( + self._project + ).data["steps"] + ) + else: + data = [] + steps = self._service_provider.projects.list_steps(self._project).data + for step in steps: + step_data = step.dict() + annotation_classes = self._service_provider.annotation_classes.list( + Condition("project_id", self._project.id, EQ) + ).data + for annotation_class in annotation_classes: + if annotation_class.id == step.class_id: + step_data["className"] = annotation_class.name + break + data.append(step_data) + self._response.data = data return self._response @@ -524,7 +537,7 @@ def validate_image_quality(self): if setting["attribute"].lower() == "imagequality" and isinstance( setting["value"], str ): - setting["value"] = constances.ImageQuality(setting["value"]).value + setting["value"] = constants.ImageQuality(setting["value"]).value return def validate_project_type(self): @@ -532,11 +545,11 @@ def validate_project_type(self): if attribute.get( "attribute", "" ) == "ImageQuality" and self._project.type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppValidationException( - constances.DEPRICATED_DOCUMENT_VIDEO_MESSAGE + constants.DEPRICATED_DOCUMENT_VIDEO_MESSAGE ) def execute(self): @@ -552,7 +565,7 @@ def execute(self): for new_setting in self._to_update: if ( new_setting["attribute"] - in constances.PROJECT_SETTINGS_VALID_ATTRIBUTES + in constants.PROJECT_SETTINGS_VALID_ATTRIBUTES ): new_settings_to_update.append( SettingEntity( @@ -579,80 +592,167 @@ def __init__( service_provider: BaseServiceProvider, steps: list, project: ProjectEntity, + connections: List[List[int]] = None, ): super().__init__() self._service_provider = service_provider self._steps = steps + self._connections = connections self._project = project def validate_project_type(self): - if self._project.type in constances.LIMITED_FUNCTIONS: + if self._project.type in constants.LIMITED_FUNCTIONS: raise AppValidationException( - constances.LIMITED_FUNCTIONS[self._project.type] + constants.LIMITED_FUNCTIONS[self._project.type] ) + def validate_connections(self): + if not self._connections: + return + if not all([len(i) == 2 for i in self._connections]): + raise AppException("Invalid connections.") + steps_count = len(self._steps) + if len(self._connections) > max( + math.factorial(steps_count) / (2 * math.factorial(steps_count - 2)), 1 + ): + raise AppValidationException( + "Invalid connections: duplicates in a connection group." + ) + + possible_connections = set(range(1, len(self._steps) + 1)) + for connection_group in self._connections: + if len(set(connection_group)) != len(connection_group): + raise AppValidationException( + "Invalid connections: duplicates in a connection group." + ) + if not set(connection_group).issubset(possible_connections): + raise AppValidationException( + "Invalid connections: index out of allowed range." + ) + + def set_basic_steps(self, annotation_classes): + annotation_classes_map = {} + annotations_classes_attributes_map = {} + for annotation_class in annotation_classes: + annotation_classes_map[annotation_class.name] = annotation_class.id + for attribute_group in annotation_class.attribute_groups: + for attribute in attribute_group.attributes: + annotations_classes_attributes_map[ + f"{annotation_class.name}__{attribute_group.name}__{attribute.name}" + ] = attribute.id + + for step in [step for step in self._steps if "className" in step]: + if step.get("id"): + del step["id"] + step["class_id"] = annotation_classes_map.get(step["className"], None) + if not step["class_id"]: + raise AppException("Annotation class not found.") + self._service_provider.projects.set_steps( + project=self._project, + steps=self._steps, + ) + existing_steps = self._service_provider.projects.list_steps(self._project).data + existing_steps_map = {} + for steps in existing_steps: + existing_steps_map[steps.step] = steps.id + + req_data = [] + for step in self._steps: + annotation_class_name = step["className"] + for attribute in step["attribute"]: + attribute_name = attribute["attribute"]["name"] + attribute_group_name = attribute["attribute"]["attribute_group"]["name"] + if not annotations_classes_attributes_map.get( + f"{annotation_class_name}__{attribute_group_name}__{attribute_name}", + None, + ): + raise AppException( + f"Attribute group name or attribute name not found {attribute_group_name}." + ) + + if not existing_steps_map.get(step["step"], None): + raise AppException("Couldn't find step in steps") + req_data.append( + { + "workflow_id": existing_steps_map[step["step"]], + "attribute_id": annotations_classes_attributes_map[ + f"{annotation_class_name}__{attribute_group_name}__{attribute_name}" + ], + } + ) + self._service_provider.projects.set_project_step_attributes( + project=self._project, + attributes=req_data, + ) + + @staticmethod + def _validate_keypoint_steps(annotation_classes, steps): + class_group_attrs_map = {} + for annotation_class in annotation_classes: + class_group_attrs_map[annotation_class.id] = dict() + for group in annotation_class.attribute_groups: + class_group_attrs_map[annotation_class.id][group.id] = [] + for attribute in group.attributes: + class_group_attrs_map[annotation_class.id][group.id].append( + attribute.id + ) + for step in steps: + class_id = step.get("class_id", None) + if not class_id or class_id not in class_group_attrs_map: + raise AppException("Annotation class not found.") + attributes = step.get("attribute", None) + if not attributes: + continue + for attr in attributes: + try: + _id, group_id = attr["attribute"].get("id", None), attr[ + "attribute" + ].get("group_id", None) + assert _id in class_group_attrs_map[class_id][group_id] + except (KeyError, AssertionError): + raise AppException("Invalid steps provided.") + + def set_keypoint_steps(self, annotation_classes, steps, connections): + self._validate_keypoint_steps(annotation_classes, steps) + for i in range(1, len(self._steps) + 1): + step = self._steps[i - 1] + step["id"] = i + if "attribute" not in step: + step["attribute"] = [] + self._service_provider.projects.set_keypoint_steps( + project=self._project, + steps=steps, + connections=connections, + ) + def execute(self): if self.is_valid(): + annotation_classes = self._service_provider.annotation_classes.list( Condition("project_id", self._project.id, EQ) ).data - annotation_classes_map = {} - annotations_classes_attributes_map = {} - for annotation_class in annotation_classes: - annotation_classes_map[annotation_class.name] = annotation_class.id - for attribute_group in annotation_class.attribute_groups: - for attribute in attribute_group.attributes: - annotations_classes_attributes_map[ - f"{annotation_class.name}__{attribute_group.name}__{attribute.name}" - ] = attribute.id - - for step in [step for step in self._steps if "className" in step]: - if step.get("id"): - del step["id"] - step["class_id"] = annotation_classes_map.get(step["className"], None) - if not step["class_id"]: - raise AppException("Annotation class not found.") - self._service_provider.projects.set_steps( - project=self._project, - steps=self._steps, - ) - existing_steps = self._service_provider.projects.list_steps( - self._project - ).data - existing_steps_map = {} - for steps in existing_steps: - existing_steps_map[steps.step] = steps.id - - req_data = [] - for step in self._steps: - annotation_class_name = step["className"] - for attribute in step["attribute"]: - attribute_name = attribute["attribute"]["name"] - attribute_group_name = attribute["attribute"]["attribute_group"][ - "name" - ] - if not annotations_classes_attributes_map.get( - f"{annotation_class_name}__{attribute_group_name}__{attribute_name}", - None, - ): - raise AppException( - f"Attribute group name or attribute name not found {attribute_group_name}." - ) - if not existing_steps_map.get(step["step"], None): - raise AppException("Couldn't find step in steps") - req_data.append( - { - "workflow_id": existing_steps_map[step["step"]], - "attribute_id": annotations_classes_attributes_map[ - f"{annotation_class_name}__{attribute_group_name}__{attribute_name}" - ], - } - ) - self._service_provider.projects.set_project_step_attributes( - project=self._project, - attributes=req_data, + project_settings = self._service_provider.projects.list_settings( + project=self._project + ).data + step_setting = next( + (i for i in project_settings if i.attribute == "WorkflowType"), None ) + if self._connections is None and step_setting.value in [ + constants.StepsType.INITIAL.value, + constants.StepsType.BASIC.value, + ]: + self.set_basic_steps(annotation_classes) + elif self._connections is not None and step_setting.value in [ + constants.StepsType.INITIAL.value, + constants.StepsType.KEYPOINT.value, + ]: + self.set_keypoint_steps( + annotation_classes, self._steps, self._connections + ) + else: + raise AppException("Can't update steps type.") + return self._response @@ -744,11 +844,11 @@ def execute(self): team_users = set() project_users = {user.user_id for user in self._project.users} for user in self._team.users: - if user.user_role == constances.UserRole.CONTRIBUTOR.value: + if user.user_role == constants.UserRole.CONTRIBUTOR.value: team_users.add(user.email) # collecting pending team users which is not admin for user in self._team.pending_invitations: - if user["user_role"] == constances.UserRole.CONTRIBUTOR.value: + if user["user_role"] == constants.UserRole.CONTRIBUTOR.value: team_users.add(user["email"]) # collecting pending project users which is not admin for user in self._project.unverified_users: @@ -831,9 +931,9 @@ def execute(self): response = self._service_provider.invite_contributors( team_id=self._team.id, # REMINDER UserRole.VIEWER is the contributor for the teams - team_role=constances.UserRole.ADMIN.value + team_role=constants.UserRole.ADMIN.value if self._set_admin - else constances.UserRole.CONTRIBUTOR.value, + else constants.UserRole.CONTRIBUTOR.value, emails=to_add, ) invited, failed = ( diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py index ed8f0a238..48e6cea63 100644 --- a/src/superannotate/lib/infrastructure/controller.py +++ b/src/superannotate/lib/infrastructure/controller.py @@ -489,10 +489,13 @@ def list_steps(self, project: ProjectEntity): ) return use_case.execute() - def set_steps(self, project: ProjectEntity, steps: List): + def set_steps( + self, project: ProjectEntity, steps: List, connections: List[List[int]] = None + ): use_case = usecases.SetStepsUseCase( service_provider=self.service_provider, steps=steps, + connections=connections, project=project, ) return use_case.execute() diff --git a/src/superannotate/lib/infrastructure/services/project.py b/src/superannotate/lib/infrastructure/services/project.py index b84d2f67a..112d7ebde 100644 --- a/src/superannotate/lib/infrastructure/services/project.py +++ b/src/superannotate/lib/infrastructure/services/project.py @@ -14,6 +14,8 @@ class ProjectService(BaseProjectService): URL_GET = "project/{}" URL_SETTINGS = "project/{}/settings" URL_STEPS = "project/{}/workflow" + URL_KEYPOINT_STEPS = "api/v1/project/{}/downloadSteps" + URL_SET_KEYPOINT_STEPS = "api/v1/project/{}/uploadSteps" URL_SHARE = "api/v1/project/{}/share/bulk" URL_SHARE_PROJECT = "project/{}/share" URL_STEP_ATTRIBUTE = "project/{}/workflow_attribute" @@ -104,6 +106,9 @@ def list_steps(self, project: entities.ProjectEntity): self.URL_STEPS.format(project.id), item_type=entities.StepEntity ) + def list_keypoint_steps(self, project: entities.ProjectEntity): + return self.client.request(self.URL_KEYPOINT_STEPS.format(project.id), "get") + def set_step(self, project: entities.ProjectEntity, step: entities.StepEntity): return self.client.request( self.URL_STEPS.format(project.id), @@ -111,6 +116,18 @@ def set_step(self, project: entities.ProjectEntity, step: entities.StepEntity): data={"steps": [step]}, ) + def set_keypoint_steps(self, project: entities.ProjectEntity, steps, connections): + return self.client.request( + self.URL_SET_KEYPOINT_STEPS.format(project.id), + "post", + data={ + "steps": { + "steps": steps, + "connections": connections if connections else [], + } + }, + ) + # TODO check def set_steps(self, project: entities.ProjectEntity, steps: list): return self.client.request( diff --git a/tests/integration/annotations/test_upload_annotations.py b/tests/integration/annotations/test_upload_annotations.py index 181ace079..dd442a8e0 100644 --- a/tests/integration/annotations/test_upload_annotations.py +++ b/tests/integration/annotations/test_upload_annotations.py @@ -1,7 +1,6 @@ import json import os import tempfile -import time from pathlib import Path from src.superannotate import AppException @@ -160,7 +159,8 @@ def setUp(self, *args, **kwargs): ], ) project = sa.controller.get_project(self.PROJECT_NAME) - time.sleep(4) + # todo check + # time.sleep(4) with open(self.EDITOR_TEMPLATE_PATH) as f: res = sa.controller.service_provider.projects.attach_editor_template( sa.controller.team, project, template=json.load(f) @@ -267,6 +267,5 @@ def test_download_annotations(self): assert len(downloaded_data) == len( annotations ), "Mismatch in annotation count" - assert ( - downloaded_data == annotations - ), "Downloaded annotations do not match uploaded annotations" + for a in downloaded_data: + assert a in annotations, "Mismatch in annotation count" diff --git a/tests/integration/items/test_attach_items.py b/tests/integration/items/test_attach_items.py index 3a8332eb7..60956aee1 100644 --- a/tests/integration/items/test_attach_items.py +++ b/tests/integration/items/test_attach_items.py @@ -106,9 +106,6 @@ def test_long_names_limitation_pass(self): } ) sa.attach_items(self.PROJECT_NAME, csv_json) - import time - - time.sleep(4) items = sa.list_items(self.PROJECT_NAME, name__in=[i["name"] for i in csv_json]) assert {i["name"] for i in items} == {i["name"] for i in csv_json} diff --git a/tests/integration/items/test_item_context.py b/tests/integration/items/test_item_context.py index 3a50e2d40..641acf6b4 100644 --- a/tests/integration/items/test_item_context.py +++ b/tests/integration/items/test_item_context.py @@ -1,6 +1,5 @@ import json import os -import time from pathlib import Path from src.superannotate import FileChangedError @@ -32,7 +31,8 @@ def setUp(self, *args, **kwargs): ) team = sa.controller.team project = sa.controller.get_project(self.PROJECT_NAME) - time.sleep(10) + # todo check + # time.sleep(10) with open(self.EDITOR_TEMPLATE_PATH) as f: res = sa.controller.service_provider.projects.attach_editor_template( team, project, template=json.load(f) @@ -65,12 +65,12 @@ def _base_test(self, path, item): def test_overwrite_false(self): # test root by folder name self._attach_item(self.PROJECT_NAME, "dummy") - time.sleep(2) + # time.sleep(2) self._base_test(self.PROJECT_NAME, "dummy") folder = sa.create_folder(self.PROJECT_NAME, folder_name="folder") # test from folder by project and folder names - time.sleep(2) + # time.sleep(2) path = f"{self.PROJECT_NAME}/folder" self._attach_item(path, "dummy") self._base_test(path, "dummy") @@ -107,7 +107,7 @@ def setUp(self, *args, **kwargs): ) team = sa.controller.team project = sa.controller.get_project(self.PROJECT_NAME) - time.sleep(10) + # time.sleep(10) with open(self.EDITOR_TEMPLATE_PATH) as f: res = sa.controller.service_provider.projects.attach_editor_template( team, project, template=json.load(f) diff --git a/tests/integration/items/test_list_items.py b/tests/integration/items/test_list_items.py index f0a80fb79..625f1573b 100644 --- a/tests/integration/items/test_list_items.py +++ b/tests/integration/items/test_list_items.py @@ -2,7 +2,6 @@ import os import random import string -import time from pathlib import Path from src.superannotate import AppException @@ -92,7 +91,6 @@ def setUp(self, *args, **kwargs): ], ) project = sa.controller.get_project(self.PROJECT_NAME) - time.sleep(10) with open(self.EDITOR_TEMPLATE_PATH) as f: res = sa.controller.service_provider.projects.attach_editor_template( sa.controller.team, project, template=json.load(f) diff --git a/tests/integration/steps/a.json b/tests/integration/steps/a.json new file mode 100644 index 000000000..0fb9263c6 --- /dev/null +++ b/tests/integration/steps/a.json @@ -0,0 +1,72 @@ +[ + { + "steps": [ + { + "class_id": 5619764, + "tool": 4, + "attribute": [ + { + "id": 11181258, + "group_id": 5464852 + } + ], + "id": 1 + }, + { + "class_id": 5619763, + "tool": 4, + "attribute": [ + { + "id": 11181255, + "group_id": 5464851 + } + ], + "id": 2 + } + ], + "connections": [ + [ + 1, + 2 + ] + ] + }, + { + "steps": [ + { + "id": 1, + "attribute": [ + { + "attribute": { + "id": 11181248, + "group_id": 5464847 + } + } + ], + "className": "1", + "class_id": 5619759, + "tool": 4 + }, + { + "id": 2, + "attribute": [ + { + "attribute": { + "id": 11181249, + "group_id": 5464848 + } + } + ], + "className": "2", + "class_id": 5619760, + "tool": 4 + } + ], + "connections": [ + [ + 2, + 1 + ] + ] + } +] \ No newline at end of file diff --git a/tests/integration/steps/test_steps.py b/tests/integration/steps/test_steps.py new file mode 100644 index 000000000..6dbddf6dc --- /dev/null +++ b/tests/integration/steps/test_steps.py @@ -0,0 +1,239 @@ +from src.superannotate import AppException +from src.superannotate import SAClient +from tests.integration.base import BaseTestCase + +sa = SAClient() + + +class TestProjectSteps(BaseTestCase): + PROJECT_NAME = "TestProjectSteps" + PROJECT_TYPE = "Vector" + + def setUp(self, *args, **kwargs): + super().setUp() + sa.create_annotation_class( + self.PROJECT_NAME, + "transport", + "#FF0000", + attribute_groups=[ + { + "name": "transport_group", + "attributes": [{"name": "Car"}, {"name": "Track"}, {"name": "Bus"}], + "default_value": "Bus", + } + ], + ) + sa.create_annotation_class( + self.PROJECT_NAME, + "passenger", + "#FF1000", + attribute_groups=[ + { + "name": "passenger_group", + "attributes": [{"name": "white"}, {"name": "black"}], + } + ], + ) + self._classes = sa.search_annotation_classes(self.PROJECT_NAME) + + def test_create_steps(self): + sa.set_project_steps( + self.PROJECT_NAME, + steps=[ + { + "class_id": self._classes[0]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[0]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[0]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + { + "class_id": self._classes[1]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[1]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[1]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + ], + connections=[[1, 2]], + ) + steps = sa.get_project_steps(self.PROJECT_NAME) + assert len(steps) == 2 + + def test_missing_ids(self): + with self.assertRaisesRegexp(AppException, "Annotation class not found."): + sa.set_project_steps( + self.PROJECT_NAME, + steps=[ + { + "class_id": 1, # invalid class id + "attribute": [ + { + "attribute": { + "id": self._classes[0]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[0]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + { + "class_id": self._classes[1]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[1]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[1]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + ], + connections=[[1, 2]], + ) + + with self.assertRaisesRegexp(AppException, "Invalid steps provided."): + sa.set_project_steps( + self.PROJECT_NAME, + steps=[ + { + "class_id": self._classes[1]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[0]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": 1, + } # invalid group id + } + ], + }, + { + "class_id": self._classes[1]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[1]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[1]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + ], + connections=[[1, 2]], + ) + + with self.assertRaisesRegexp(AppException, "Invalid steps provided."): + sa.set_project_steps( + self.PROJECT_NAME, + steps=[ + { + "class_id": self._classes[1]["id"], + "attribute": [ + { + "attribute": { + "id": 1, # invalid attr id + "group_id": self._classes[0]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + { + "class_id": self._classes[1]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[1]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[1]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + ], + connections=[[1, 2]], + ) + + def test_create_invalid_connection(self): + args = ( + self.PROJECT_NAME, + [ + { + "class_id": self._classes[0]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[0]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[0]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + { + "class_id": self._classes[1]["id"], + "attribute": [ + { + "attribute": { + "id": self._classes[1]["attribute_groups"][0][ + "attributes" + ][0]["id"], + "group_id": self._classes[1]["attribute_groups"][0][ + "id" + ], + } + } + ], + }, + ], + ) + with self.assertRaisesRegexp( + AppException, "Invalid connections: duplicates in a connection group." + ): + sa.set_project_steps( + *args, + connections=[ + [1, 2], + [2, 1], + ] + ) + with self.assertRaisesRegexp( + AppException, "Invalid connections: index out of allowed range." + ): + sa.set_project_steps(*args, connections=[[1, 3]]) diff --git a/tests/integration/work_management/test_pause_resume_user_activity.py b/tests/integration/work_management/test_pause_resume_user_activity.py index 5cbdfeb70..50ae15919 100644 --- a/tests/integration/work_management/test_pause_resume_user_activity.py +++ b/tests/integration/work_management/test_pause_resume_user_activity.py @@ -36,9 +36,6 @@ def test_pause_and_resume_user_activity(self): scapegoat = [ u for u in users if u["role"] == "Contributor" and u["state"] == "Confirmed" ][0] - import pdb - - pdb.set_trace() sa.add_contributors_to_project(self.PROJECT_NAME, [scapegoat["email"]], "QA") with self.assertLogs("sa", level="INFO") as cm: sa.pause_user_activity(pk=scapegoat["email"], projects=[self.PROJECT_NAME]) diff --git a/tests/integration/work_management/test_user_custom_fields.py b/tests/integration/work_management/test_user_custom_fields.py index 293fb79e6..9dc256868 100644 --- a/tests/integration/work_management/test_user_custom_fields.py +++ b/tests/integration/work_management/test_user_custom_fields.py @@ -142,7 +142,6 @@ def test_list_users(self): custom_field_name="SDK_test_date_picker", value=value, ) - time.sleep(1) scapegoat = sa.list_users( include=["custom_fields"], email=scapegoat["email"],