@@ -227,21 +227,20 @@ def promote(
227227 on_complete: A callback to call on each successfully promoted snapshot.
228228 """
229229
230- gateway_by_schema : t .Dict [t .Any , str ] = {}
231- tables : t .List [t .Any ] = []
230+ tables_by_gateway : t .Dict [t .Union [str , None ], t .List [exp .Table ]] = defaultdict (list )
232231 for snapshot in target_snapshots :
233232 if snapshot .is_model and not snapshot .is_symbolic :
233+ gateway = (
234+ snapshot .model_gateway if environment_naming_info .gateway_managed else None
235+ )
236+ adapter = self ._get_adapter (gateway )
234237 table = snapshot .qualified_view_name .table_for_environment (
235- environment_naming_info ,
236- dialect = self ._get_adapter (snapshot .model_gateway ).dialect
237- if environment_naming_info .gateway_managed
238- else self .adapter .dialect ,
238+ environment_naming_info , dialect = adapter .dialect
239239 )
240- tables .append (table )
241- if environment_naming_info .gateway_managed :
242- table_schema = d .schema_ (table .db , catalog = table .catalog )
243- gateway_by_schema [table_schema ] = snapshot .model_gateway or ""
244- self ._create_schemas (tables = tables , gateways = gateway_by_schema )
240+ tables_by_gateway [gateway ].append (table )
241+
242+ for gateway , tables in tables_by_gateway .items ():
243+ self ._create_schemas (tables = tables , gateway = gateway )
245244
246245 deployability_index = deployability_index or DeployabilityIndex .all_deployable ()
247246 with self .concurrent_context ():
@@ -301,8 +300,9 @@ def create(
301300 allow_destructive_snapshots: Set of snapshots that are allowed to have destructive schema changes.
302301 """
303302 snapshots_with_table_names = defaultdict (set )
304- tables_by_schema = defaultdict (set )
305- gateway_by_schema : t .Dict [exp .Table , str ] = {}
303+ tables_by_gateway_and_schema : t .Dict [t .Union [str , None ], t .Dict [exp .Table , set [str ]]] = (
304+ defaultdict (lambda : defaultdict (set ))
305+ )
306306 table_deployability : t .Dict [str , bool ] = {}
307307 allow_destructive_snapshots = allow_destructive_snapshots or set ()
308308
@@ -324,24 +324,32 @@ def create(
324324 snapshots_with_table_names [snapshot ].add (table .name )
325325 table_deployability [table .name ] = is_deployable
326326 table_schema = d .schema_ (table .db , catalog = table .catalog )
327- tables_by_schema [table_schema ].add (table .name )
328- gateway_by_schema [table_schema ] = snapshot .model .gateway or ""
327+ tables_by_gateway_and_schema [snapshot .model_gateway ][table_schema ].add (table .name )
329328
330- def _get_data_objects (schema : exp .Table , gateway : t .Optional [str ] = None ) -> t .Set [str ]:
329+ def _get_data_objects (
330+ schema : exp .Table ,
331+ object_names : t .Optional [t .Set [str ]] = None ,
332+ gateway : t .Optional [str ] = None ,
333+ ) -> t .Set [str ]:
331334 logger .info ("Listing data objects in schema %s" , schema .sql ())
332- objs = self .get_adapter (gateway ).get_data_objects (schema , tables_by_schema [ schema ] )
335+ objs = self ._get_adapter (gateway ).get_data_objects (schema , object_names )
333336 return {obj .name for obj in objs }
334337
335338 with self .concurrent_context ():
336- existing_objects = {
337- obj
338- for objs in concurrent_apply_to_values (
339- list (tables_by_schema ),
340- lambda s : _get_data_objects (s , gateway_by_schema [s ]),
341- self .ddl_concurrent_tasks ,
342- )
343- for obj in objs
344- }
339+ existing_objects : t .Set [str ] = set ()
340+ for gateway , tables_by_schema in tables_by_gateway_and_schema .items ():
341+ objs_for_gateway = {
342+ obj
343+ for objs in concurrent_apply_to_values (
344+ list (tables_by_schema ),
345+ lambda s : _get_data_objects (
346+ schema = s , object_names = tables_by_schema .get (s ), gateway = gateway
347+ ),
348+ self .ddl_concurrent_tasks ,
349+ )
350+ for obj in objs
351+ }
352+ existing_objects .update (objs_for_gateway )
345353
346354 snapshots_to_create = []
347355 target_deployability_flags : t .Dict [str , t .List [bool ]] = defaultdict (list )
@@ -359,7 +367,10 @@ def _get_data_objects(schema: exp.Table, gateway: t.Optional[str] = None) -> t.S
359367 return
360368 if on_start :
361369 on_start (len (snapshots_to_create ))
362- self ._create_schemas (tables_by_schema , gateway_by_schema )
370+
371+ for gateway , tables_by_schema in tables_by_gateway_and_schema .items ():
372+ self ._create_schemas (tables = tables_by_schema , gateway = gateway )
373+
363374 self ._create_snapshots (
364375 snapshots_to_create = snapshots_to_create ,
365376 snapshots = snapshots ,
@@ -1072,7 +1083,7 @@ def _audit(
10721083 def _create_schemas (
10731084 self ,
10741085 tables : t .Iterable [t .Union [exp .Table , str ]],
1075- gateways : t .Optional [t . Dict [ exp . Table , str ] ] = None ,
1086+ gateway : t .Optional [str ] = None ,
10761087 ) -> None :
10771088 table_exprs = [exp .to_table (t ) for t in tables ]
10781089 unique_schemas = {(t .args ["db" ], t .args .get ("catalog" )) for t in table_exprs if t and t .db }
@@ -1081,7 +1092,7 @@ def _create_schemas(
10811092 for schema_name , catalog in unique_schemas :
10821093 schema = schema_ (schema_name , catalog )
10831094 logger .info ("Creating schema '%s'" , schema )
1084- adapter = self .get_adapter ( gateways . get ( schema )) if gateways else self . adapter
1095+ adapter = self ._get_adapter ( gateway )
10851096 adapter .create_schema (schema )
10861097
10871098 def get_adapter (self , gateway : t .Optional [str ] = None ) -> EngineAdapter :
0 commit comments