Skip to content

Commit 69ed551

Browse files
authored
Merge pull request #764 from superannotateai/FRIDAY_3554
Added ability to filter by item category
2 parents 761f01e + 04904fb commit 69ed551

File tree

11 files changed

+126
-11
lines changed

11 files changed

+126
-11
lines changed

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3421,7 +3421,7 @@ def list_items(
34213421
exclude = {"meta", "annotator_email", "qa_email"}
34223422
if not include_custom_metadata:
34233423
exclude.add("custom_metadata")
3424-
return BaseSerializer.serialize_iterable(res, exclude=exclude)
3424+
return BaseSerializer.serialize_iterable(res, exclude=exclude, by_alias=False)
34253425

34263426
def list_projects(
34273427
self,

src/superannotate/lib/core/entities/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from lib.core.entities.classes import AnnotationClassEntity
55
from lib.core.entities.folder import FolderEntity
66
from lib.core.entities.integrations import IntegrationEntity
7+
from lib.core.entities.items import CategoryEntity
78
from lib.core.entities.items import ClassificationEntity
89
from lib.core.entities.items import DocumentEntity
910
from lib.core.entities.items import ImageEntity
@@ -12,7 +13,6 @@
1213
from lib.core.entities.items import TiledEntity
1314
from lib.core.entities.items import VideoEntity
1415
from lib.core.entities.project import AttachmentEntity
15-
from lib.core.entities.project import CategoryEntity
1616
from lib.core.entities.project import ContributorEntity
1717
from lib.core.entities.project import CustomFieldEntity
1818
from lib.core.entities.project import ProjectEntity

src/superannotate/lib/core/entities/filters.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class ItemFilters(BaseFilters):
2929
assignments__user_role__in: Optional[List[str]]
3030
assignments__user_role__ne: Optional[str]
3131
assignments__user_role__notin: Optional[List[str]]
32+
categories__value: Optional[str]
33+
categories__value__in: Optional[List[str]]
3234

3335

3436
class ProjectFilters(BaseFilters):

src/superannotate/lib/core/entities/items.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Optional
33

44
from lib.core.entities.base import BaseItemEntity
5-
from lib.core.entities.base import TimedBaseModel
5+
from lib.core.entities.project import TimedBaseModel
66
from lib.core.enums import ApprovalStatus
77
from lib.core.enums import ProjectType
88
from lib.core.pydantic_v1 import Extra
@@ -18,9 +18,17 @@ class Config:
1818
extra = Extra.ignore
1919

2020

21+
class CategoryEntity(TimedBaseModel):
22+
id: int
23+
value: str = Field(None, alias="name")
24+
25+
class Config:
26+
extra = Extra.ignore
27+
28+
2129
class MultiModalItemCategoryEntity(TimedBaseModel):
2230
id: int = Field(None, alias="category_id")
23-
name: str = Field(None, alias="category_name")
31+
value: str = Field(None, alias="category_name")
2432

2533
class Config:
2634
extra = Extra.ignore

src/superannotate/lib/core/entities/project.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,3 @@ def is_system(self):
187187

188188
class Config:
189189
extra = Extra.ignore
190-
191-
192-
class CategoryEntity(BaseModel):
193-
id: Optional[int]
194-
name: Optional[str]

src/superannotate/lib/core/serviceproviders.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,12 @@ class BaseServiceProvider:
698698
def get_role_id(self, project: entities.ProjectEntity, role_name: str) -> int:
699699
raise NotImplementedError
700700

701+
@abstractmethod
702+
def get_category_id(
703+
self, project: entities.ProjectEntity, category_name: str
704+
) -> int:
705+
raise NotImplementedError
706+
701707
@abstractmethod
702708
def get_role_name(self, project: entities.ProjectEntity, role_id: int) -> str:
703709
raise NotImplementedError

src/superannotate/lib/core/usecases/annotations.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,8 +2101,10 @@ def execute(self):
21012101
if categorization_enabled:
21022102
item_id_category_map = {}
21032103
for item_name in uploaded_annotations:
2104-
category = name_annotation_map[item_name]["metadata"].get(
2105-
"item_category"
2104+
category = (
2105+
name_annotation_map[item_name]["metadata"]
2106+
.get("item_category", {})
2107+
.get("value")
21062108
)
21072109
if category:
21082110
item_id_category_map[name_item_map[item_name].id] = category

src/superannotate/lib/infrastructure/query_builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def handle(self, filters: Dict[str, Any], query: Query = None) -> Query:
113113
for key, val in filters.items():
114114
_keys = key.split("__")
115115
val = self._handle_special_fields(_keys, val)
116+
if _keys[0] == "categories" and _keys[1] == "value":
117+
_keys[1] = "category_id"
116118
condition, _key = determine_condition_and_key(_keys)
117119
query &= Filter(_key, val, condition)
118120
return super().handle(filters, query)
@@ -147,6 +149,14 @@ def _handle_special_fields(self, keys: List[str], val):
147149
]
148150
else:
149151
val = self._service_provider.get_role_id(self._project, val)
152+
elif keys[0] == "categories" and keys[1] == "value":
153+
if isinstance(val, list):
154+
val = [
155+
self._service_provider.get_category_id(self._project, i)
156+
for i in val
157+
]
158+
else:
159+
val = self._service_provider.get_category_id(self._project, val)
150160
return val
151161

152162

src/superannotate/lib/infrastructure/serviceprovider.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ def list_custom_field_names(self, entity: CustomFieldEntityEnum) -> List[str]:
7979
self.client.team_id, entity=entity
8080
)
8181

82+
def get_category_id(
83+
self, project: entities.ProjectEntity, category_name: str
84+
) -> int:
85+
return self._cached_work_management_repository.get_category_id(
86+
project, category_name
87+
)
88+
8289
def get_custom_field_id(
8390
self, field_name: str, entity: CustomFieldEntityEnum
8491
) -> int:

src/superannotate/lib/infrastructure/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,23 @@ def get(self, key, **kwargs):
138138
return self._K_V_map[key]
139139

140140

141+
class CategoryCache(BaseCachedWorkManagementRepository):
142+
def sync(self, project: ProjectEntity):
143+
response = self.work_management.list_project_categories(project.id)
144+
if not response.ok:
145+
raise AppException(response.error)
146+
categories = response.data
147+
self._K_V_map[project.id] = {
148+
"category_name_id_map": {
149+
category.value: category.id for category in categories
150+
},
151+
"category_id_name_map": {
152+
category.id: category.value for category in categories
153+
},
154+
}
155+
self._update_cache_timestamp(project.id)
156+
157+
141158
class RoleCache(BaseCachedWorkManagementRepository):
142159
def sync(self, project: ProjectEntity):
143160
response = self.work_management.list_workflow_roles(
@@ -221,6 +238,7 @@ def get(self, key, **kwargs):
221238

222239
class CachedWorkManagementRepository:
223240
def __init__(self, ttl_seconds: int, work_management):
241+
self._category_cache = CategoryCache(ttl_seconds, work_management)
224242
self._role_cache = RoleCache(ttl_seconds, work_management)
225243
self._status_cache = StatusCache(ttl_seconds, work_management)
226244
self._project_custom_field_cache = CustomFieldCache(
@@ -236,6 +254,12 @@ def __init__(self, ttl_seconds: int, work_management):
236254
CustomFieldEntityEnum.TEAM,
237255
)
238256

257+
def get_category_id(self, project, category_name: str) -> int:
258+
data = self._category_cache.get(project.id, project=project)
259+
if category_name in data["category_name_id_map"]:
260+
return data["category_name_id_map"][category_name]
261+
raise AppException("Invalid category provided.")
262+
239263
def get_role_id(self, project, role_name: str) -> int:
240264
role_data = self._role_cache.get(project.id, project=project)
241265
if role_name in role_data["role_name_id_map"]:

0 commit comments

Comments
 (0)