From d7f2758a2549cc5dc300b1da3aa20cffc1749ecc Mon Sep 17 00:00:00 2001 From: rubin Date: Sun, 19 Apr 2026 16:10:45 -0300 Subject: [PATCH 1/4] fix(mysql): map division opcodes to the correct operators --- internal/engine/dolphin/convert.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 1f68358ce4..68e248e198 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -137,6 +137,8 @@ func opToName(o opcode.Op) string { // case opcode.BitNeg: // case opcode.Case: // case opcode.Div: + case opcode.Div: + return "/" case opcode.EQ: return "=" case opcode.GE: @@ -145,7 +147,7 @@ func opToName(o opcode.Op) string { return ">" // case opcode.In: case opcode.IntDiv: - return "/" + return "div" // case opcode.IsFalsity: // case opcode.IsNull: // case opcode.IsTruth: From b645325dac6290ad56f5b0eb96fdee5dcb17306d Mon Sep 17 00:00:00 2001 From: rubin Date: Sun, 19 Apr 2026 16:11:05 -0300 Subject: [PATCH 2/4] feat(mysql): infer types for simple numeric expressions --- internal/compiler/infer_expr_type.go | 273 +++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 internal/compiler/infer_expr_type.go diff --git a/internal/compiler/infer_expr_type.go b/internal/compiler/infer_expr_type.go new file mode 100644 index 0000000000..663fdec624 --- /dev/null +++ b/internal/compiler/infer_expr_type.go @@ -0,0 +1,273 @@ +package compiler + +import ( + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +// +// ============================== +// Internal Type System +// ============================== +// + +type Kind int + +const ( + KindUnknown Kind = iota // inference not supported + KindInt + KindFloat + KindDecimal + KindAny +) + +type Type struct { + Kind Kind + NotNull bool + Valid bool // explicit signal: inference succeeded +} + +func unknownType() Type { + return Type{Kind: KindUnknown, Valid: false} +} + +// +// ============================== +// Entry Point +// ============================== +// + +func (c *Compiler) inferExprType(node ast.Node, tables []*Table) *Column { + if node == nil { + return nil + } + + switch c.conf.Engine { + case config.EngineMySQL: + t := c.inferMySQLExpr(node, tables) + return c.mysqlTypeToColumn(t) + + // case config.EnginePostgreSQL: + // t := c.inferPostgresExpr(node, tables) + // return c.postgresTypeToColumn(t) + + default: + return nil + } +} + +// +// ============================== +// MySQL Inference +// ============================== +// + +func (c *Compiler) inferMySQLExpr(node ast.Node, tables []*Table) Type { + switch n := node.(type) { + case *ast.ColumnRef: + return c.inferMySQLColumnRef(n, tables) + + case *ast.A_Const: + return inferConst(n) + + case *ast.TypeCast: + return c.inferMySQLTypeCast(n, tables) + + case *ast.A_Expr: + return c.inferMySQLBinary(n, tables) + + default: + return unknownType() + } +} + +// +// ------------------------------ +// Leaf nodes +// ------------------------------ +// + +func (c *Compiler) inferMySQLColumnRef(ref *ast.ColumnRef, tables []*Table) Type { + cols, err := outputColumnRefs(&ast.ResTarget{}, tables, ref) + if err != nil || len(cols) == 0 { + return unknownType() + } + + col := cols[0] + + return Type{ + Kind: mapMySQLKind(col.DataType), + NotNull: col.NotNull, + Valid: true, + } +} + +func inferConst(node *ast.A_Const) Type { + if node == nil || node.Val == nil { + return unknownType() + } + + switch node.Val.(type) { + case *ast.Integer: + return Type{Kind: KindInt, NotNull: true, Valid: true} + + case *ast.Float: + return Type{Kind: KindFloat, NotNull: true, Valid: true} + + case *ast.Null: + return Type{Kind: KindAny, NotNull: false, Valid: true} + + default: + return unknownType() + } +} + +func (c *Compiler) inferMySQLTypeCast(node *ast.TypeCast, tables []*Table) Type { + if node == nil || node.TypeName == nil { + return unknownType() + } + + base := toColumn(node.TypeName) + if base == nil { + return unknownType() + } + + arg := c.inferMySQLExpr(node.Arg, tables) + + t := Type{ + Kind: mapMySQLKind(base.DataType), + Valid: true, + } + + // propagate nullability + if arg.Valid { + t.NotNull = arg.NotNull + } + + // explicit NULL literal + if constant, ok := node.Arg.(*ast.A_Const); ok { + if _, isNull := constant.Val.(*ast.Null); isNull { + t.NotNull = false + } + } + + return t +} + +// +// ------------------------------ +// Binary expressions +// ------------------------------ +// + +func (c *Compiler) inferMySQLBinary(node *ast.A_Expr, tables []*Table) Type { + op := joinOperator(node) + + left := c.inferMySQLExpr(node.Lexpr, tables) + right := c.inferMySQLExpr(node.Rexpr, tables) + + if !left.Valid || !right.Valid { + return unknownType() + } + + // NOTE: only normal division ("/") is supported for now. + // Unsupported operators intentionally fall back to the existing behavior. + return promoteMySQLNumeric(op, left, right) +} + +// +// ============================== +// Promotion Rules (MySQL-specific for now) +// ============================== +// + +// promoteMySQLNumeric applies simplified numeric promotion rules for MySQL. +// It currently only supports "/" and intentionally falls back for other operators. +func promoteMySQLNumeric(op string, a, b Type) Type { + notNull := a.NotNull && b.NotNull + + switch op { + case "/": + if a.Kind == KindFloat || b.Kind == KindFloat { + return Type{ + Kind: KindFloat, + NotNull: notNull, + Valid: true, + } + } + + return Type{ + Kind: KindDecimal, + NotNull: notNull, + Valid: true, + } + } + + return unknownType() +} + +// +// ============================== +// Engine-specific Mapping +// ============================== +// + +func (c *Compiler) mysqlTypeToColumn(t Type) *Column { + if !t.Valid { + return nil + } + + col := &Column{ + NotNull: t.NotNull, + } + + switch t.Kind { + case KindInt: + col.DataType = "int" + + case KindFloat: + col.DataType = "float" + + case KindDecimal: + col.DataType = "decimal" + + default: + col.DataType = "any" + } + + return col +} + +func mapMySQLKind(dt string) Kind { + switch dt { + case "int", "integer", "bigint", "smallint": + return KindInt + + case "float", "double", "real": + return KindFloat + + case "decimal", "numeric": + return KindDecimal + + default: + return KindUnknown + } +} + +// +// ============================== +// AST helpers +// ============================== +// + +func joinOperator(node *ast.A_Expr) string { + if node == nil || node.Name == nil || len(node.Name.Items) == 0 { + return "" + } + + if s, ok := node.Name.Items[0].(*ast.String); ok { + return s.Str + } + + return "" +} From d2b9c0b87489308003347154c9dd4f20e4a8076e Mon Sep 17 00:00:00 2001 From: rubin Date: Sun, 19 Apr 2026 16:11:16 -0300 Subject: [PATCH 3/4] fix(mysql): use expression inference while preserving fallback behavior --- internal/compiler/output_columns.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index dbd486359a..f4e283ea8a 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -155,7 +155,13 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er // TODO: Generate a name for these operations cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true}) case lang.IsMathematicalOperator(op): - cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true}) + if inferredCol := c.inferExprType(n, tables); inferredCol != nil { + inferredCol.Name = name + inferredCol.skipTableRequiredCheck = true + cols = append(cols, inferredCol) + } else { + cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true}) + } default: cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false}) } From d6c15e2b1fbfc4abbec8171ea9b93f6b619d9b31 Mon Sep 17 00:00:00 2001 From: rubin Date: Thu, 23 Apr 2026 14:48:32 -0300 Subject: [PATCH 4/4] fix(mysql): correct TypeCast inference and add end-to-end tests --- internal/compiler/infer_expr_type.go | 7 +- .../mysql/go/db.go | 31 +++++ .../mysql/go/models.go | 16 +++ .../mysql/go/query.sql.go | 119 ++++++++++++++++++ .../mysql/query.sql | 11 ++ .../mysql/schema.sql | 6 + .../mysql/sqlc.json | 14 +++ 7 files changed, 201 insertions(+), 3 deletions(-) create mode 100644 internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/db.go create mode 100644 internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/models.go create mode 100644 internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/mysql_expression_type_inference/mysql/query.sql create mode 100644 internal/endtoend/testdata/mysql_expression_type_inference/mysql/schema.sql create mode 100644 internal/endtoend/testdata/mysql_expression_type_inference/mysql/sqlc.json diff --git a/internal/compiler/infer_expr_type.go b/internal/compiler/infer_expr_type.go index 663fdec624..67240bf210 100644 --- a/internal/compiler/infer_expr_type.go +++ b/internal/compiler/infer_expr_type.go @@ -127,15 +127,16 @@ func (c *Compiler) inferMySQLTypeCast(node *ast.TypeCast, tables []*Table) Type return unknownType() } - base := toColumn(node.TypeName) - if base == nil { + // MySQL populates TypeName.Name directly; toColumn reads TypeName.Names (Postgres-style). + kind := mapMySQLKind(node.TypeName.Name) + if kind == KindUnknown { return unknownType() } arg := c.inferMySQLExpr(node.Arg, tables) t := Type{ - Kind: mapMySQLKind(base.DataType), + Kind: kind, Valid: true, } diff --git a/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/db.go b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/models.go b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/models.go new file mode 100644 index 0000000000..c5b0ebf03c --- /dev/null +++ b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/models.go @@ -0,0 +1,16 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "database/sql" +) + +type Metric struct { + ID int32 + Value sql.NullFloat64 + Count int32 + Ratio sql.NullFloat64 +} diff --git a/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/query.sql.go b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/query.sql.go new file mode 100644 index 0000000000..316a3a5725 --- /dev/null +++ b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/go/query.sql.go @@ -0,0 +1,119 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const floatDivByConst = `-- name: FloatDivByConst :many +SELECT (value / 1024) AS scaled_value FROM metrics +` + +func (q *Queries) FloatDivByConst(ctx context.Context) ([]sql.NullFloat64, error) { + rows, err := q.db.QueryContext(ctx, floatDivByConst) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullFloat64 + for rows.Next() { + var scaled_value sql.NullFloat64 + if err := rows.Scan(&scaled_value); err != nil { + return nil, err + } + items = append(items, scaled_value) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const floatDivByFloat = `-- name: FloatDivByFloat :many +SELECT (value / ratio) AS proportion FROM metrics +` + +func (q *Queries) FloatDivByFloat(ctx context.Context) ([]sql.NullFloat64, error) { + rows, err := q.db.QueryContext(ctx, floatDivByFloat) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullFloat64 + for rows.Next() { + var proportion sql.NullFloat64 + if err := rows.Scan(&proportion); err != nil { + return nil, err + } + items = append(items, proportion) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const intDivByConst = `-- name: IntDivByConst :many +SELECT (count / 10) AS avg_count FROM metrics +` + +func (q *Queries) IntDivByConst(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, intDivByConst) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var avg_count string + if err := rows.Scan(&avg_count); err != nil { + return nil, err + } + items = append(items, avg_count) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const notNullFloatDivByConst = `-- name: NotNullFloatDivByConst :many +SELECT (CAST(value AS FLOAT) / 1024) AS scaled FROM metrics +` + +func (q *Queries) NotNullFloatDivByConst(ctx context.Context) ([]sql.NullFloat64, error) { + rows, err := q.db.QueryContext(ctx, notNullFloatDivByConst) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullFloat64 + for rows.Next() { + var scaled sql.NullFloat64 + if err := rows.Scan(&scaled); err != nil { + return nil, err + } + items = append(items, scaled) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/mysql_expression_type_inference/mysql/query.sql b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/query.sql new file mode 100644 index 0000000000..c1c8f1200f --- /dev/null +++ b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/query.sql @@ -0,0 +1,11 @@ +-- name: FloatDivByConst :many +SELECT (value / 1024) AS scaled_value FROM metrics; + +-- name: IntDivByConst :many +SELECT (count / 10) AS avg_count FROM metrics; + +-- name: FloatDivByFloat :many +SELECT (value / ratio) AS proportion FROM metrics; + +-- name: NotNullFloatDivByConst :many +SELECT (CAST(value AS FLOAT) / 1024) AS scaled FROM metrics; diff --git a/internal/endtoend/testdata/mysql_expression_type_inference/mysql/schema.sql b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/schema.sql new file mode 100644 index 0000000000..6c58791222 --- /dev/null +++ b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE metrics ( + id INT NOT NULL PRIMARY KEY, + value FLOAT NULL, + count INT NOT NULL, + ratio DOUBLE NULL +); diff --git a/internal/endtoend/testdata/mysql_expression_type_inference/mysql/sqlc.json b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/sqlc.json new file mode 100644 index 0000000000..7dabfeef72 --- /dev/null +++ b/internal/endtoend/testdata/mysql_expression_type_inference/mysql/sqlc.json @@ -0,0 +1,14 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "sql_package": "database/sql", + "sql_driver": "github.com/go-sql-driver/mysql", + "engine": "mysql", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +}