You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
3.8 KiB
161 lines
3.8 KiB
package database
|
|
|
|
import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/go-sql-driver/mysql"
	gorm_mysql "gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"

	"task-track-backend/internal/config"
	"task-track-backend/model"
)
|
|
|
|
func Init(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
|
// 首先尝试连接到数据库
|
|
db, err := connectToDatabase(cfg)
|
|
if err != nil {
|
|
// 如果连接失败,检查是否是因为数据库不存在
|
|
if isDBNotExistError(err) {
|
|
// 创建数据库
|
|
err = createDatabase(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create database: %w", err)
|
|
}
|
|
|
|
// 重新尝试连接
|
|
db, err = connectToDatabase(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to database after creation: %w", err)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
}
|
|
|
|
// 自动迁移数据库表
|
|
err = autoMigrateModels(db)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to auto migrate models: %w", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func connectToDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local",
|
|
cfg.Username,
|
|
cfg.Password,
|
|
cfg.Host,
|
|
cfg.Port,
|
|
cfg.Database,
|
|
cfg.Charset,
|
|
)
|
|
|
|
// 配置 GORM 连接选项
|
|
gormConfig := &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
// 启用错误翻译,将数据库特定错误转换为 GORM 通用错误
|
|
TranslateError: true,
|
|
}
|
|
|
|
// 配置 MySQL 驱动选项
|
|
mysqlConfig := gorm_mysql.Config{
|
|
DSN: dsn,
|
|
DefaultStringSize: 256,
|
|
DisableDatetimePrecision: true,
|
|
DontSupportRenameIndex: true,
|
|
DontSupportRenameColumn: true,
|
|
SkipInitializeWithVersion: false,
|
|
}
|
|
|
|
db, err := gorm.Open(gorm_mysql.New(mysqlConfig), gormConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 配置连接池
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 设置连接池参数
|
|
sqlDB.SetMaxIdleConns(10)
|
|
sqlDB.SetMaxOpenConns(100)
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func createDatabase(cfg config.DatabaseConfig) error {
|
|
// 创建不包含数据库名的 DSN,用于连接 MySQL 服务器
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/?charset=%s&parseTime=True&loc=Local",
|
|
cfg.Username,
|
|
cfg.Password,
|
|
cfg.Host,
|
|
cfg.Port,
|
|
cfg.Charset,
|
|
)
|
|
|
|
// 直接使用 database/sql 连接
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to MySQL server: %w", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// 测试连接
|
|
if err := db.Ping(); err != nil {
|
|
return fmt.Errorf("failed to ping MySQL server: %w", err)
|
|
}
|
|
|
|
// 创建数据库
|
|
createDBSQL := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s` CHARACTER SET %s COLLATE %s_general_ci",
|
|
cfg.Database, cfg.Charset, cfg.Charset)
|
|
|
|
_, err = db.Exec(createDBSQL)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create database %s: %w", cfg.Database, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func autoMigrateModels(db *gorm.DB) error {
|
|
// 使用 GORM 的 AutoMigrate 自动迁移所有模型
|
|
// 设置表选项(MySQL 存储引擎)
|
|
err := db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(
|
|
&model.User{},
|
|
&model.Organization{},
|
|
&model.UserOrganization{},
|
|
&model.Task{},
|
|
&model.TaskTag{},
|
|
&model.TaskTagRelation{},
|
|
&model.TaskComment{},
|
|
&model.TaskAttachment{},
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("auto migrate failed: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func isDBNotExistError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
|
|
// 检查是否是 MySQL 错误
|
|
if mysqlErr, ok := err.(*mysql.MySQLError); ok {
|
|
// MySQL 错误码 1049: Unknown database
|
|
return mysqlErr.Number == 1049
|
|
}
|
|
|
|
// 检查错误消息中是否包含数据库不存在的关键字
|
|
errMsg := strings.ToLower(err.Error())
|
|
return strings.Contains(errMsg, "unknown database") ||
|
|
strings.Contains(errMsg, "database doesn't exist")
|
|
}
|
|
|