Skip to content

Commit 5d44a98

Browse files
committed
fix: support generic class inheritance with type arguments
Fixes #1929 - Classes can now extend generic classes with specific type parameters and fields are properly resolved. Changes: - Parser: Support doc.type.sign for class extends (e.g., Bar: Foo<integer>) - Parser: Include doc.field in bindGeneric() to convert generic type names - VM: Add resolveGenericField() to resolve generic parameters in field types - VM: Handle doc.type.sign extends in class field inheritance search - VM: Export vm.cloneObject() for generic type resolution Example that now works: ---@Class Foo<T> ---@field a T ---@Class Bar: Foo<integer> local x ---@type Bar local what = x.a -- Now infers as 'integer' instead of 'unknown'
1 parent eacc3d8 commit 5d44a98

4 files changed

Lines changed: 129 additions & 3 deletions

File tree

script/parser/luadoc.lua

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,12 @@ local docSwitch = util.switch()
10301030
}
10311031
return result
10321032
end
1033+
if extend.type == 'doc.extends.name' then
1034+
local signResult = parseTypeUnitSign(result, extend)
1035+
if signResult then
1036+
extend = signResult
1037+
end
1038+
end
10331039
result.extends[#result.extends+1] = extend
10341040
result.finish = getFinish()
10351041
if not checkToken('symbol', ',', 1) then
@@ -1850,7 +1856,8 @@ local function bindGeneric(binded)
18501856
or doc.type == 'doc.return'
18511857
or doc.type == 'doc.type'
18521858
or doc.type == 'doc.class'
1853-
or doc.type == 'doc.alias' then
1859+
or doc.type == 'doc.alias'
1860+
or doc.type == 'doc.field' then
18541861
guide.eachSourceType(doc, 'doc.type.name', function (src)
18551862
local name = src[1]
18561863
if generics[name] then

script/vm/compiler.lua

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,56 @@ local function searchLiteralFieldFromTable(source, key, callback)
234234
end
235235
end
236236

237+
---@param uri uri
238+
---@param classGlobal vm.global
239+
---@param field parser.object
240+
---@param signs parser.object[]
241+
---@return parser.object?
242+
local function resolveGenericField(uri, classGlobal, field, signs)
243+
if field.type ~= 'doc.field' then
244+
return nil
245+
end
246+
if not field.extends then
247+
return nil
248+
end
249+
local hasGeneric = false
250+
guide.eachSourceType(field.extends, 'doc.generic.name', function ()
251+
hasGeneric = true
252+
end)
253+
if not hasGeneric then
254+
return nil
255+
end
256+
for _, set in ipairs(classGlobal:getSets(uri)) do
257+
if set.type == 'doc.class' and set.signs then
258+
local resolved = {}
259+
for i, signName in ipairs(set.signs) do
260+
local signType = signs[i]
261+
if signType and signName[1] then
262+
local signNode = vm.compileNode(signType)
263+
resolved[signName[1]] = signNode
264+
end
265+
end
266+
if next(resolved) then
267+
local newExtends = vm.cloneObject(field.extends, resolved)
268+
if newExtends then
269+
return {
270+
type = field.type,
271+
start = field.start,
272+
finish = field.finish,
273+
parent = field.parent,
274+
field = field.field,
275+
extends = newExtends,
276+
visible = field.visible,
277+
optional = field.optional,
278+
}
279+
end
280+
end
281+
break
282+
end
283+
end
284+
return nil
285+
end
286+
237287
local searchFieldSwitch = util.switch()
238288
: case 'table'
239289
: call(function (_suri, source, key, pushResult)
@@ -357,7 +407,16 @@ local searchFieldSwitch = util.switch()
357407
if not globalVar then
358408
return
359409
end
360-
vm.getClassFields(suri, globalVar, key, pushResult)
410+
vm.getClassFields(suri, globalVar, key, function (field, isMark)
411+
if source.signs then
412+
local newField = resolveGenericField(suri, globalVar, field, source.signs)
413+
if newField then
414+
pushResult(newField, isMark)
415+
return
416+
end
417+
end
418+
pushResult(field, isMark)
419+
end)
361420
end)
362421
: case 'global'
363422
: call(function (suri, node, key, pushResult)
@@ -565,14 +624,21 @@ function vm.getClassFields(suri, object, key, pushResult)
565624

566625
for _, set in ipairs(sets) do
567626
if set.type == 'doc.class' then
568-
-- look into extends(if field not found)
569627
if not searchedFields[key] and set.extends then
570628
for _, extend in ipairs(set.extends) do
571629
if extend.type == 'doc.extends.name' then
572630
local extendType = vm.getGlobal('type', extend[1])
573631
if extendType then
574632
searchClass(extendType, searchedFields)
575633
end
634+
elseif extend.type == 'doc.type.sign' then
635+
searchFieldSwitch(extend.type, suri, extend, key, function (field, isMark)
636+
local fieldKey = guide.getKeyName(field)
637+
if fieldKey and not searchedFields[fieldKey] then
638+
hasFounded[fieldKey] = true
639+
pushResult(field, isMark)
640+
end
641+
end)
576642
end
577643
end
578644
end

script/vm/generic.lua

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,10 @@ function vm.createGeneric(proto, sign)
173173
}, mt)
174174
return generic
175175
end
176+
177+
---@param source vm.object?
178+
---@param resolved? table<string, vm.node>
179+
---@return vm.object?
180+
function vm.cloneObject(source, resolved)
181+
return cloneObject(source, resolved)
182+
end

test/type_inference/common.lua

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4909,3 +4909,49 @@ function f(...args)
49094909
print(<?args?>)
49104910
end
49114911
]]
4912+
4913+
TEST 'integer' [[
4914+
---@class Foo<T>
4915+
---@field a T
4916+
4917+
---@class Bar: Foo<integer>
4918+
4919+
---@type Bar
4920+
local x
4921+
local <?what?> = x.a
4922+
]]
4923+
4924+
TEST 'string' [[
4925+
---@class GenericBase<T>
4926+
---@field value T
4927+
4928+
---@class StringHolder: GenericBase<string>
4929+
4930+
---@type StringHolder
4931+
local holder
4932+
local <?v?> = holder.value
4933+
]]
4934+
4935+
TEST 'boolean' [[
4936+
---@class Container<K, V>
4937+
---@field key K
4938+
---@field val V
4939+
4940+
---@class BoolContainer: Container<string, boolean>
4941+
4942+
---@type BoolContainer
4943+
local c
4944+
local <?b?> = c.val
4945+
]]
4946+
4947+
TEST 'string' [[
4948+
---@class Container<K, V>
4949+
---@field key K
4950+
---@field val V
4951+
4952+
---@class BoolContainer: Container<string, boolean>
4953+
4954+
---@type BoolContainer
4955+
local c
4956+
local <?k?> = c.key
4957+
]]

0 commit comments

Comments
 (0)