@@ -254,6 +254,75 @@ def a_function():
254254 self .mod3 .read (),
255255 )
256256
257+ def test_adding_imports_preferred_import_style_is_normal_import (self ) -> None :
258+ self .project .prefs .imports .preferred_import_style = "normal-import"
259+ self .origin_module .write (dedent ("""\
260+ class AClass(object):
261+ pass
262+ def a_function():
263+ pass
264+ """ ))
265+ self .mod3 .write (dedent ("""\
266+ import origin_module
267+ a_var = origin_module.AClass()
268+ origin_module.a_function()""" ))
269+ # Move to destination_module_in_pkg which is in a different package
270+ self ._move (self .origin_module , self .origin_module .read ().index ("AClass" ) + 1 , self .destination_module_in_pkg )
271+ self .assertEqual (
272+ dedent ("""\
273+ import origin_module
274+ import pkg.destination_module_in_pkg
275+ a_var = pkg.destination_module_in_pkg.AClass()
276+ origin_module.a_function()""" ),
277+ self .mod3 .read (),
278+ )
279+
280+ def test_adding_imports_preferred_import_style_is_from_module (self ) -> None :
281+ self .project .prefs .imports .preferred_import_style = "from-module"
282+ self .origin_module .write (dedent ("""\
283+ class AClass(object):
284+ pass
285+ def a_function():
286+ pass
287+ """ ))
288+ self .mod3 .write (dedent ("""\
289+ import origin_module
290+ a_var = origin_module.AClass()
291+ origin_module.a_function()""" ))
292+ # Move to destination_module_in_pkg which is in a different package
293+ self ._move (self .origin_module , self .origin_module .read ().index ("AClass" ) + 1 , self .destination_module_in_pkg )
294+ self .assertEqual (
295+ dedent ("""\
296+ import origin_module
297+ from pkg import destination_module_in_pkg
298+ a_var = destination_module_in_pkg.AClass()
299+ origin_module.a_function()""" ),
300+ self .mod3 .read (),
301+ )
302+
303+ def test_adding_imports_preferred_import_style_is_from_global (self ) -> None :
304+ self .project .prefs .imports .preferred_import_style = "from-global"
305+ self .origin_module .write (dedent ("""\
306+ class AClass(object):
307+ pass
308+ def a_function():
309+ pass
310+ """ ))
311+ self .mod3 .write (dedent ("""\
312+ import origin_module
313+ a_var = origin_module.AClass()
314+ origin_module.a_function()""" ))
315+ # Move to destination_module_in_pkg which is in a different package
316+ self ._move (self .origin_module , self .origin_module .read ().index ("AClass" ) + 1 , self .destination_module_in_pkg )
317+ self .assertEqual (
318+ dedent ("""\
319+ import origin_module
320+ from pkg.destination_module_in_pkg import AClass
321+ a_var = AClass()
322+ origin_module.a_function()""" ),
323+ self .mod3 .read (),
324+ )
325+
257326 def test_adding_imports_noprefer_from_module (self ) -> None :
258327 self .project .prefs ["prefer_module_from_imports" ] = False
259328 self .origin_module .write (dedent ("""\
0 commit comments