66
77
88def parse_dss_managed_folder_uri (uri ):
9- """Parse an S3 URI, returning (bucket, path)"""
109 parsed = urllib .parse .urlparse (uri )
1110 if parsed .scheme != "dss-managed-folder" :
1211 raise Exception ("Not a DSS Managed Folder URI: %s" % uri )
13- return os .path .normpath (parsed .path )
14-
12+ if not parsed .netloc or parsed .netloc == '.' :
13+ raise Exception ("Could not find a managed folder id in URI: %s" % uri )
14+ return parsed
1515
1616class PluginDSSManagedFolderArtifactRepository :
1717
1818 def __init__ (self , artifact_uri ):
19- self .base_artifact_path = parse_dss_managed_folder_uri (artifact_uri )
2019 if os .environ .get ("DSS_MLFLOW_APIKEY" ) is not None :
2120 self .client = DSSClient (
2221 os .environ .get ("DSS_MLFLOW_HOST" ),
@@ -28,14 +27,19 @@ def __init__(self, artifact_uri):
2827 internal_ticket = os .environ .get ("DSS_MLFLOW_INTERNAL_TICKET" )
2928 )
3029 self .project = self .client .get_project (os .environ .get ("DSS_MLFLOW_PROJECTKEY" ))
31- managed_folders = [
32- x ["id" ] for x in self .project .list_managed_folders ()
33- if x ["name" ] == os .environ .get ("DSS_MLFLOW_MANAGED_FOLDER" )
34- ]
35- if len (managed_folders ) > 0 :
36- self .managed_folder = self .project .get_managed_folder (managed_folders [0 ])
30+ parsed_uri = parse_dss_managed_folder_uri (artifact_uri )
31+ self .managed_folder = self .__get_managed_folder (parsed_uri .netloc )
32+ self .base_artifact_path = os .path .normpath (parsed_uri .path )
33+
34+ def __get_managed_folder (self , managed_folder_smart_id ):
35+ chunks = managed_folder_smart_id .split ('.' )
36+ if len (chunks ) == 1 :
37+ return self .project .get_managed_folder (chunks [0 ])
38+ elif len (chunks ) == 2 :
39+ project = self .client .get_project (chunks [0 ])
40+ return project .get_managed_folder (chunks [1 ])
3741 else :
38- self . managed_folder = self . project . create_managed_folder ( os . environ . get ( "DSS_MLFLOW_MANAGED_FOLDER" ) )
42+ raise Exception ( "Invalid managed folder id: %s" % managed_folder_smart_id )
3943
4044 def log_artifact (self , local_file , artifact_path = None ):
4145 """
0 commit comments