动手写ORM框架 - GeeORM第三天 记录新增和查询
源代码/数据集已上传到
Github - 7days-golang
本文是7天用Go从零实现ORM框架GeeORM 的第三篇。
实现新增(insert)记录的功能。
使用反射(reflect)将数据库的记录转换为对应的结构体实例,实现查询(select)功能。代码约150行
1 Clause 构造 SQL 语句 从第三天开始,GeeORM 需要涉及一些较为复杂的操作,例如查询操作。查询语句一般由很多个子句(clause) 构成。SELECT 语句的构成通常是这样的:
1 2 3 4 5 SELECT col1, col2, ... FROM table_name WHERE [ conditions ] GROUP BY col1 HAVING [ conditions ]
也就是说,如果想一次构造出完整的 SQL 语句是比较困难的,因此我们将构造 SQL 语句这一部分独立出来,放在子package clause 中实现。
首先在 clause/generator.go
中实现各个子句的生成规则。
day3-save-query/clause/generator.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 package clauseimport ( "fmt" "strings" ) type generator func (values ...interface {}) (string , []interface {}) var generators map [Type]generatorfunc init () { generators = make (map [Type]generator) generators[INSERT] = _insert generators[VALUES] = _values generators[SELECT] = _select generators[LIMIT] = _limit generators[WHERE] = _where generators[ORDERBY] = _orderBy } func genBindVars (num int ) string { var vars []string for i := 0 ; i < num; i++ { vars = append (vars, "?" ) } return strings.Join(vars, ", " ) } func _insert (values ...interface {}) (string , []interface {}) { tableName := values[0 ] fields := strings.Join(values[1 ].([]string ), "," ) return fmt.Sprintf("INSERT INTO %s (%v)" , tableName, fields), []interface {}{} } func _values (values ...interface {}) (string , []interface {}) { var bindStr string var sql strings.Builder var vars []interface {} sql.WriteString("VALUES " ) for i, value := range values { v := value.([]interface {}) if bindStr == "" { bindStr = genBindVars(len (v)) } sql.WriteString(fmt.Sprintf("(%v)" , bindStr)) if i+1 != len (values) { sql.WriteString(", " ) } vars = append (vars, v...) } return sql.String(), vars } func _select (values ...interface {}) (string , []interface {}) { tableName := values[0 ] fields := strings.Join(values[1 ].([]string ), "," ) return fmt.Sprintf("SELECT %v FROM %s" , fields, tableName), []interface {}{} } func _limit (values ...interface {}) (string , []interface {}) { return "LIMIT ?" , values } func _where (values ...interface {}) (string , []interface {}) { desc, vars := values[0 ], values[1 :] return fmt.Sprintf("WHERE %s" , desc), vars } func _orderBy (values ...interface {}) (string , []interface {}) { return fmt.Sprintf("ORDER BY %s" , values[0 ]), []interface {}{} }
然后在 clause/clause.go
中实现结构体 Clause
拼接各个独立的子句。
day3-save-query/clause/clause.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 package clauseimport "strings" type Clause struct { sql map [Type]string sqlVars map [Type][]interface {} } type Type int const ( INSERT Type = iota VALUES SELECT LIMIT WHERE ORDERBY ) func (c *Clause) Set (name Type, vars ...interface {}) { if c.sql == nil { c.sql = make (map [Type]string ) c.sqlVars = make (map [Type][]interface {}) } sql, vars := generators[name](vars...) c.sql[name] = sql c.sqlVars[name] = vars } func (c *Clause) Build (orders ...Type) (string , []interface {}) { var sqls []string var vars []interface {} for _, order := range orders { if sql, ok := c.sql[order]; ok { sqls = append (sqls, sql) vars = append (vars, c.sqlVars[order]...) } } return strings.Join(sqls, " " ), vars }
Set
方法根据 Type
调用对应的 generator,生成该子句对应的 SQL 语句。
Build
方法根据传入的 Type
的顺序,构造出最终的 SQL 语句。
在 clause_test.go
实现对应的测试用例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 func testSelect (t *testing.T) { var clause Clause clause.Set(LIMIT, 3 ) clause.Set(SELECT, "User" , []string {"*" }) clause.Set(WHERE, "Name = ?" , "Tom" ) clause.Set(ORDERBY, "Age ASC" ) sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT) t.Log(sql, vars) if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" { t.Fatal("failed to build SQL" ) } if !reflect.DeepEqual(vars, []interface {}{"Tom" , 3 }) { t.Fatal("failed to build SQLVars" ) } } func TestClause_Build (t *testing.T) { t.Run("select" , func (t *testing.T) { testSelect(t) }) }
2 实现 Insert 功能 首先为 Session 添加成员变量 clause
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 type Session struct { db *sql.DB dialect dialect.Dialect refTable *schema.Schema clause clause.Clause sql strings.Builder sqlVars []interface {} } func (s *Session) Clear () { s.sql.Reset() s.sqlVars = nil s.clause = clause.Clause{} }
clause 已经支持生成简单的插入(INSERT) 和 查询(SELECT) 的 SQL 语句,那么紧接着我们就可以在 session 中实现对应的功能了。
INSERT 对应的 SQL 语句一般是这样的:
1 2 3 4 INSERT INTO table_name(col1, col2, col3, ...) VALUES (A1, A2, A3, ...), (B1, B2, B3, ...), ...
在 ORM 框架中期望 Insert 的调用方式如下:
1 2 3 4 s := geeorm.NewEngine("sqlite3" , "gee.db" ).NewSession() u1 := &User{Name: "Tom" , Age: 18 } u2 := &User{Name: "Sam" , Age: 25 } s.Insert(u1, u2, ...)
也就是说,我们还需要一个步骤,根据数据库中列的顺序,从对象中找到对应的值,按顺序平铺。即 u1
、u2
转换为 ("Tom", 18), ("Same", 25)
这样的格式。
因此在实现 Insert 功能之前,还需要给 Schema
新增一个函数 RecordValues
完成上述的转换。
day3-save-query/schema/schema.go
1 2 3 4 5 6 7 8 func (schema *Schema) RecordValues (dest interface {}) []interface {} { destValue := reflect.Indirect(reflect.ValueOf(dest)) var fieldValues []interface {} for _, field := range schema.Fields { fieldValues = append (fieldValues, destValue.FieldByName(field.Name).Interface()) } return fieldValues }
在 session 文件夹下新建 record.go,用于实现记录增删查改相关的代码。
day3-save-query/session/record.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 package sessionimport ( "geeorm/clause" "reflect" ) func (s *Session) Insert (values ...interface {}) (int64 , error) { recordValues := make ([]interface {}, 0 ) for _, value := range values { table := s.Model(value).RefTable() s.clause.Set(clause.INSERT, table.Name, table.FieldNames) recordValues = append (recordValues, table.RecordValues(value)) } s.clause.Set(clause.VALUES, recordValues...) sql, vars := s.clause.Build(clause.INSERT, clause.VALUES) result, err := s.Raw(sql, vars...).Exec() if err != nil { return 0 , err } return result.RowsAffected() }
后续所有构造 SQL 语句的方式都将与 Insert
中构造 SQL 语句的方式一致。分两步:
1)多次调用 clause.Set()
构造好每一个子句。
2)调用一次 clause.Build()
按照传入的顺序构造出最终的 SQL 语句。
构造完成后,调用 Raw().Exec()
方法执行。
3 实现 Find 功能 期望的调用方式是这样的:传入一个切片指针,查询的结果保存在切片中。
1 2 3 s := geeorm.NewEngine("sqlite3" , "gee.db" ).NewSession() var users []Users.Find(&users);
Find 功能的难点和 Insert 恰好反了过来。Insert 需要将已经存在的对象的每一个字段的值平铺开来,而 Find 则是需要根据平铺开的字段的值构造出对象。同样,也需要用到反射(reflect)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 func (s *Session) Find (values interface {}) error { destSlice := reflect.Indirect(reflect.ValueOf(values)) destType := destSlice.Type().Elem() table := s.Model(reflect.New(destType).Elem().Interface()).RefTable() s.clause.Set(clause.SELECT, table.Name, table.FieldNames) sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT) rows, err := s.Raw(sql, vars...).QueryRows() if err != nil { return err } for rows.Next() { dest := reflect.New(destType).Elem() var values []interface {} for _, name := range table.FieldNames { values = append (values, dest.FieldByName(name).Addr().Interface()) } if err := rows.Scan(values...); err != nil { return err } destSlice.Set(reflect.Append(destSlice, dest)) } return rows.Close() }
Find 的代码实现比较复杂,主要分为以下几步:
destSlice.Type().Elem()
获取切片的单个元素的类型 destType
,使用 reflect.New()
方法创建一个 destType
的实例,作为 Model()
的入参,映射出表结构 RefTable()
。
2)根据表结构,使用 clause 构造出 SELECT 语句,查询到所有符合条件的记录 rows
。
3)遍历每一行记录,利用反射创建 destType
的实例 dest
,将 dest
的所有字段平铺开,构造切片 values
。
4)调用 rows.Scan()
将该行记录每一列的值依次赋值给 values 中的每一个字段。
5)将 dest
添加到切片 destSlice
中。循环直到所有的记录都添加到切片 destSlice
中。
4 测试 在 session 文件夹下新建 record_test.go
,创建测试用例。
User
和 NewSession()
的定义位于 raw_test.go 中。
day3-save-query/session/record_test.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 package sessionimport "testing" var ( user1 = &User{"Tom" , 18 } user2 = &User{"Sam" , 25 } user3 = &User{"Jack" , 25 } ) func testRecordInit (t *testing.T) *Session { t.Helper() s := NewSession().Model(&User{}) err1 := s.DropTable() err2 := s.CreateTable() _, err3 := s.Insert(user1, user2) if err1 != nil || err2 != nil || err3 != nil { t.Fatal("failed init test records" ) } return s } func TestSession_Insert (t *testing.T) { s := testRecordInit(t) affected, err := s.Insert(user3) if err != nil || affected != 1 { t.Fatal("failed to create record" ) } } func TestSession_Find (t *testing.T) { s := testRecordInit(t) var users []User if err := s.Find(&users); err != nil || len (users) != 2 { t.Fatal("failed to query all" ) } }
附 推荐阅读
last updated at 2023-11-15