55import sys
66from threading import Thread
77from concurrent .futures import ThreadPoolExecutor
8+ from typing import Callable
89
910
1011class DiffgramDatasetIterator :
@@ -15,13 +16,15 @@ class DiffgramDatasetIterator:
1516 file_cache : dict
1617 _internal_file_list : list
1718 current_file_index : int
19+ custom_signer_fn : Callable
1820
1921 def __init__ (self ,
2022 project ,
2123 diffgram_file_id_list ,
2224 validate_ids = True ,
2325 max_size_cache = 1073741824 ,
24- max_num_concurrent_fetches = 25 ):
26+ max_num_concurrent_fetches = 25 ,
27+ custom_signer_fn = None ):
2528 """
2629
2730 :param project (sdk.core.core.Project): A Project object from the Diffgram SDK
@@ -30,6 +33,7 @@ def __init__(self,
3033 self .diffgram_file_id_list = []
3134 self .max_size_cache = 1073741824
3235 self .pool = None
36+ self .custom_signer_fn = custom_signer_fn
3337 self .file_cache = {}
3438 self ._internal_file_list = []
3539 self .current_file_index = 0
@@ -118,16 +122,24 @@ def __validate_file_ids(self):
118122 raise Exception (
119123 'Some file IDs do not belong to the project. Please provide only files from the same project.' )
120124
def set_custom_url_signer(self, signer_fn: Callable) -> None:
    """Install (or clear) the callable used to sign image URLs.

    When a signer is set, ``get_image_data`` calls it as
    ``signer_fn(blob_path, bucket_name)`` and passes the returned value
    to ``imread`` in place of the file's own ``url_signed`` entry.

    :param signer_fn: A callable taking ``(blob_path, bucket_name)`` and
        returning the URL to read the image from, or ``None`` to remove
        a previously installed signer.
    :raises TypeError: If ``signer_fn`` is neither ``None`` nor callable.
    """
    # Reject bad values up front: otherwise the mistake only surfaces
    # later, when the fetch code tries to call the signer inside its
    # retry loop, far away from the line that caused it.
    if signer_fn is not None and not callable(signer_fn):
        raise TypeError(
            'signer_fn must be a callable accepting (blob_path, bucket_name) '
            'or None, got {}'.format(type(signer_fn).__name__))
    self.custom_signer_fn = signer_fn
127+
121128 def get_image_data (self , diffgram_file ):
122129 MAX_RETRIES = 10
123130 image = None
124131 if hasattr (diffgram_file , 'image' ):
125132 for i in range (0 , MAX_RETRIES ):
126133 try :
134+ url = None
127135 if diffgram_file .image :
128136 url = diffgram_file .image .get ('url_signed' )
129- if url :
130- image = imread (diffgram_file .image .get ('url_signed' ))
137+ if diffgram_file .image and self .custom_signer_fn is not None :
138+ blob_path = diffgram_file .image ['url_signed_blob_path' ]
139+ bucket_name = diffgram_file .image ['bucket_name' ]
140+ url = self .custom_signer_fn (blob_path , bucket_name )
141+ if url :
142+ image = imread (url )
131143 break
132144 except Exception as e :
133145 if i < MAX_RETRIES - 1 :
0 commit comments