You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

161 lines
3.8 KiB

package database
import (
"database/sql"
"fmt"
"strings"
"task-track-backend/internal/config"
"task-track-backend/model"
"github.com/go-sql-driver/mysql"
gorm_mysql "gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func Init(cfg config.DatabaseConfig) (*gorm.DB, error) {
// 首先尝试连接到数据库
db, err := connectToDatabase(cfg)
if err != nil {
// 如果连接失败,检查是否是因为数据库不存在
if isDBNotExistError(err) {
// 创建数据库
err = createDatabase(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create database: %w", err)
}
// 重新尝试连接
db, err = connectToDatabase(cfg)
if err != nil {
return nil, fmt.Errorf("failed to connect to database after creation: %w", err)
}
} else {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
}
// 自动迁移数据库表
err = autoMigrateModels(db)
if err != nil {
return nil, fmt.Errorf("failed to auto migrate models: %w", err)
}
return db, nil
}
func connectToDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local",
cfg.Username,
cfg.Password,
cfg.Host,
cfg.Port,
cfg.Database,
cfg.Charset,
)
// 配置 GORM 连接选项
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
// 启用错误翻译,将数据库特定错误转换为 GORM 通用错误
TranslateError: true,
}
// 配置 MySQL 驱动选项
mysqlConfig := gorm_mysql.Config{
DSN: dsn,
DefaultStringSize: 256,
DisableDatetimePrecision: true,
DontSupportRenameIndex: true,
DontSupportRenameColumn: true,
SkipInitializeWithVersion: false,
}
db, err := gorm.Open(gorm_mysql.New(mysqlConfig), gormConfig)
if err != nil {
return nil, err
}
// 配置连接池
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
// 设置连接池参数
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100)
return db, nil
}
func createDatabase(cfg config.DatabaseConfig) error {
// 创建不包含数据库名的 DSN,用于连接 MySQL 服务器
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/?charset=%s&parseTime=True&loc=Local",
cfg.Username,
cfg.Password,
cfg.Host,
cfg.Port,
cfg.Charset,
)
// 直接使用 database/sql 连接
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to connect to MySQL server: %w", err)
}
defer db.Close()
// 测试连接
if err := db.Ping(); err != nil {
return fmt.Errorf("failed to ping MySQL server: %w", err)
}
// 创建数据库
createDBSQL := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s` CHARACTER SET %s COLLATE %s_general_ci",
cfg.Database, cfg.Charset, cfg.Charset)
_, err = db.Exec(createDBSQL)
if err != nil {
return fmt.Errorf("failed to create database %s: %w", cfg.Database, err)
}
return nil
}
func autoMigrateModels(db *gorm.DB) error {
// 使用 GORM 的 AutoMigrate 自动迁移所有模型
// 设置表选项(MySQL 存储引擎)
err := db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(
&model.User{},
&model.Organization{},
&model.UserOrganization{},
&model.Task{},
&model.TaskTag{},
&model.TaskTagRelation{},
&model.TaskComment{},
&model.TaskAttachment{},
)
if err != nil {
return fmt.Errorf("auto migrate failed: %w", err)
}
return nil
}
func isDBNotExistError(err error) bool {
if err == nil {
return false
}
// 检查是否是 MySQL 错误
if mysqlErr, ok := err.(*mysql.MySQLError); ok {
// MySQL 错误码 1049: Unknown database
return mysqlErr.Number == 1049
}
// 检查错误消息中是否包含数据库不存在的关键字
errMsg := strings.ToLower(err.Error())
return strings.Contains(errMsg, "unknown database") ||
strings.Contains(errMsg, "database doesn't exist")
}