package database import ( "database/sql" "fmt" "strings" "task-track-backend/internal/config" "task-track-backend/model" "github.com/go-sql-driver/mysql" gorm_mysql "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) var db *gorm.DB func Init(cfg config.DatabaseConfig) (*gorm.DB, error) { // 首先尝试连接到数据库 db, err := connectToDatabase(cfg) if err != nil { // 如果连接失败,检查是否是因为数据库不存在 if isDBNotExistError(err) { // 创建数据库 err = createDatabase(cfg) if err != nil { return nil, fmt.Errorf("failed to create database: %w", err) } // 重新尝试连接 db, err = connectToDatabase(cfg) if err != nil { return nil, fmt.Errorf("failed to connect to database after creation: %w", err) } } else { return nil, fmt.Errorf("failed to connect to database: %w", err) } } // 自动迁移数据库表 err = autoMigrateModels(db) if err != nil { return nil, fmt.Errorf("failed to auto migrate models: %w", err) } // 设置全局数据库实例 setDB(db) return db, nil } // GetDB 获取全局数据库实例 func GetDB() *gorm.DB { return db } // setDB 设置全局数据库实例 func setDB(database *gorm.DB) { db = database } func connectToDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) { dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local", cfg.Username, cfg.Password, cfg.Host, cfg.Port, cfg.Database, cfg.Charset, ) // 配置 GORM 连接选项 gormConfig := &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), // 启用错误翻译,将数据库特定错误转换为 GORM 通用错误 TranslateError: true, } // 配置 MySQL 驱动选项 mysqlConfig := gorm_mysql.Config{ DSN: dsn, DefaultStringSize: 256, DisableDatetimePrecision: true, DontSupportRenameIndex: true, DontSupportRenameColumn: true, SkipInitializeWithVersion: false, } db, err := gorm.Open(gorm_mysql.New(mysqlConfig), gormConfig) if err != nil { return nil, err } // 配置连接池 sqlDB, err := db.DB() if err != nil { return nil, err } // 设置连接池参数 sqlDB.SetMaxIdleConns(10) sqlDB.SetMaxOpenConns(100) return db, nil } func createDatabase(cfg config.DatabaseConfig) error { // 创建不包含数据库名的 DSN,用于连接 MySQL 服务器 dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/?charset=%s&parseTime=True&loc=Local", cfg.Username, cfg.Password, cfg.Host, cfg.Port, cfg.Charset, ) // 直接使用 database/sql 连接 db, err := sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("failed to connect to MySQL server: %w", err) } defer db.Close() // 测试连接 if err := db.Ping(); err != nil { return fmt.Errorf("failed to ping MySQL server: %w", err) } // 创建数据库 createDBSQL := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s` CHARACTER SET %s COLLATE %s_general_ci", cfg.Database, cfg.Charset, cfg.Charset) _, err = db.Exec(createDBSQL) if err != nil { return fmt.Errorf("failed to create database %s: %w", cfg.Database, err) } return nil } func autoMigrateModels(db *gorm.DB) error { // 使用 GORM 的 AutoMigrate 自动迁移所有模型 // 设置表选项(MySQL 存储引擎) err := db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate( &model.User{}, &model.Organization{}, &model.UserOrganization{}, &model.Task{}, &model.TaskTag{}, &model.TaskTagRelation{}, &model.TaskComment{}, &model.TaskAttachment{}, ) if err != nil { return fmt.Errorf("auto migrate failed: %w", err) } return nil } func isDBNotExistError(err error) bool { if err == nil { return false } // 检查是否是 MySQL 错误 if mysqlErr, ok := err.(*mysql.MySQLError); ok { // MySQL 错误码 1049: Unknown database return mysqlErr.Number == 1049 } // 检查错误消息中是否包含数据库不存在的关键字 errMsg := strings.ToLower(err.Error()) return strings.Contains(errMsg, "unknown database") || strings.Contains(errMsg, "database doesn't exist") }