Skip to content

Commit a4bbb9e

Browse files
lieryannicoolas25
andcommitted
Implement preferred_import_style
This is a configuration option to select the import style that rope will use when adding new imports. Co-authored-by: Nicolas Zermati <nicoolas25@gmail.com>
1 parent 4bba655 commit a4bbb9e

2 files changed

Lines changed: 75 additions & 1 deletion

File tree

rope/refactor/importutils/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import rope.base.codeanalyze
99
import rope.base.evaluate
1010
from rope.base import libutils
11+
from rope.base.prefs import get_preferred_import_style
12+
from rope.base.prefs import ImportStyle
1113
from rope.base.change import ChangeContents, ChangeSet
1214
from rope.refactor import occurrences, rename
1315
from rope.refactor.importutils import actions, module_imports
@@ -299,20 +301,23 @@ def get_module_imports(project, pymodule):
299301

300302

301303
def add_import(project, pymodule, module_name, name=None):
304+
preferred_import_style = get_preferred_import_style(project.prefs)
302305
imports = get_module_imports(project, pymodule)
303306
candidates = []
304307
names = []
305308
selected_import = None
306309
# from mod import name
307310
if name is not None:
308311
from_import = FromImport(module_name, 0, [(name, None)])
312+
if preferred_import_style == ImportStyle.from_global:
313+
selected_import = from_import
309314
names.append(name)
310315
candidates.append(from_import)
311316
# from pkg import mod
312317
if "." in module_name:
313318
pkg, mod = module_name.rsplit(".", 1)
314319
from_import = FromImport(pkg, 0, [(mod, None)])
315-
if project.prefs.get("prefer_module_from_imports"):
320+
if preferred_import_style == ImportStyle.from_module:
316321
selected_import = from_import
317322
candidates.append(from_import)
318323
if name:

ropetest/refactor/movetest.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,75 @@ def a_function():
254254
self.mod3.read(),
255255
)
256256

257+
def test_adding_imports_preferred_import_style_is_normal_import(self) -> None:
258+
self.project.prefs.imports.preferred_import_style = "normal-import"
259+
self.origin_module.write(dedent("""\
260+
class AClass(object):
261+
pass
262+
def a_function():
263+
pass
264+
"""))
265+
self.mod3.write(dedent("""\
266+
import origin_module
267+
a_var = origin_module.AClass()
268+
origin_module.a_function()"""))
269+
# Move to destination_module_in_pkg which is in a different package
270+
self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg)
271+
self.assertEqual(
272+
dedent("""\
273+
import origin_module
274+
import pkg.destination_module_in_pkg
275+
a_var = pkg.destination_module_in_pkg.AClass()
276+
origin_module.a_function()"""),
277+
self.mod3.read(),
278+
)
279+
280+
def test_adding_imports_preferred_import_style_is_from_module(self) -> None:
281+
self.project.prefs.imports.preferred_import_style = "from-module"
282+
self.origin_module.write(dedent("""\
283+
class AClass(object):
284+
pass
285+
def a_function():
286+
pass
287+
"""))
288+
self.mod3.write(dedent("""\
289+
import origin_module
290+
a_var = origin_module.AClass()
291+
origin_module.a_function()"""))
292+
# Move to destination_module_in_pkg which is in a different package
293+
self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg)
294+
self.assertEqual(
295+
dedent("""\
296+
import origin_module
297+
from pkg import destination_module_in_pkg
298+
a_var = destination_module_in_pkg.AClass()
299+
origin_module.a_function()"""),
300+
self.mod3.read(),
301+
)
302+
303+
def test_adding_imports_preferred_import_style_is_from_global(self) -> None:
304+
self.project.prefs.imports.preferred_import_style = "from-global"
305+
self.origin_module.write(dedent("""\
306+
class AClass(object):
307+
pass
308+
def a_function():
309+
pass
310+
"""))
311+
self.mod3.write(dedent("""\
312+
import origin_module
313+
a_var = origin_module.AClass()
314+
origin_module.a_function()"""))
315+
# Move to destination_module_in_pkg which is in a different package
316+
self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg)
317+
self.assertEqual(
318+
dedent("""\
319+
import origin_module
320+
from pkg.destination_module_in_pkg import AClass
321+
a_var = AClass()
322+
origin_module.a_function()"""),
323+
self.mod3.read(),
324+
)
325+
257326
def test_adding_imports_noprefer_from_module(self) -> None:
258327
self.project.prefs["prefer_module_from_imports"] = False
259328
self.origin_module.write(dedent("""\

0 commit comments

Comments
 (0)