♻️ Refactor code

🔥 Remove update detection
This commit is contained in:
2024-04-23 14:47:53 +08:00
parent ebc91d8aad
commit ac4ad3c8aa
16 changed files with 39 additions and 94 deletions

View File

@ -0,0 +1,56 @@
package database
import (
"path/filepath"
"sub2clash/common"
"sub2clash/logger"
"sub2clash/model"
"github.com/glebarez/sqlite"
"go.uber.org/zap"
"gorm.io/gorm"
)
var DB *gorm.DB
func ConnectDB() error {
// 用上面的数据库连接初始化 gorm
err := common.MKDir("data")
if err != nil {
return err
}
db, err := gorm.Open(
sqlite.Open(filepath.Join("data", "sub2clash.db")), &gorm.Config{
Logger: nil,
},
)
if err != nil {
return err
}
DB = db
err = db.AutoMigrate(&model.ShortLink{})
if err != nil {
return err
}
return nil
}
func FindShortLinkByUrl(url string, shortLink *model.ShortLink) *gorm.DB {
logger.Logger.Debug("find short link by url", zap.String("url", url))
return DB.Where("url = ?", url).First(&shortLink)
}
func FindShortLinkByHash(hash string, shortLink *model.ShortLink) *gorm.DB {
logger.Logger.Debug("find short link by hash", zap.String("hash", hash))
return DB.Where("hash = ?", hash).First(&shortLink)
}
func SaveShortLink(shortLink *model.ShortLink) {
logger.Logger.Debug("save short link", zap.String("hash", shortLink.Hash))
DB.Save(shortLink)
}
func FirstOrCreateShortLink(shortLink *model.ShortLink) {
logger.Logger.Debug("first or create short link", zap.String("hash", shortLink.Hash))
DB.FirstOrCreate(shortLink)
}

38
common/get.go Normal file
View File

@ -0,0 +1,38 @@
package common
import (
"errors"
"net/http"
"sub2clash/config"
"time"
)
func Get(url string) (resp *http.Response, err error) {
retryTimes := config.Default.RequestRetryTimes
haveTried := 0
retryDelay := time.Second // 延迟1秒再重试
for haveTried < retryTimes {
client := &http.Client{}
//client.Timeout = time.Second * 10
req, err := http.NewRequest("GET", url, nil)
if err != nil {
haveTried++
time.Sleep(retryDelay)
continue
}
get, err := client.Do(req)
if err != nil {
haveTried++
time.Sleep(retryDelay)
continue
} else {
// 如果文件大小大于设定,直接返回错误
if get != nil && get.ContentLength > config.Default.RequestMaxFileSize {
return nil, errors.New("文件过大")
}
return get, nil
}
}
return nil, err
}

30
common/mkdir.go Normal file
View File

@ -0,0 +1,30 @@
package common
import (
"errors"
"os"
)
func MKDir(dir string) error {
if _, err := os.Stat(dir); os.IsNotExist(err) {
err := os.MkdirAll(dir, os.ModePerm)
if err != nil {
return err
}
}
return nil
}
func MkEssentialDir() error {
if err := MKDir("subs"); err != nil {
return errors.New("create subs dir failed" + err.Error())
}
if err := MKDir("templates"); err != nil {
return errors.New("create templates dir failed" + err.Error())
}
if err := MKDir("logs"); err != nil {
return errors.New("create logs dir failed" + err.Error())
}
return nil
}

141
common/proxy.go Normal file
View File

@ -0,0 +1,141 @@
package common
import (
"strings"
"sub2clash/constant"
"sub2clash/logger"
"sub2clash/model"
"sub2clash/parser"
"go.uber.org/zap"
)
func GetContryName(countryKey string) string {
// 创建一个切片包含所有的国家映射
countryMaps := []map[string]string{
model.CountryFlag,
model.CountryChineseName,
model.CountryISO,
model.CountryEnglishName,
}
// 对每一个映射进行检查
for i, countryMap := range countryMaps {
if i == 2 {
// 对ISO匹配做特殊处理
// 根据常用分割字符分割字符串
splitChars := []string{"-", "_", " "}
key := make([]string, 0)
for _, splitChar := range splitChars {
slic := strings.Split(countryKey, splitChar)
for _, v := range slic {
if len(v) == 2 {
key = append(key, v)
}
}
}
// 对每一个分割后的字符串进行检查
for _, v := range key {
// 如果匹配到了国家
if country, ok := countryMap[strings.ToUpper(v)]; ok {
return country
}
}
}
for k, v := range countryMap {
if strings.Contains(countryKey, k) {
return v
}
}
}
return "其他地区"
}
func AddProxy(
sub *model.Subscription, autotest bool,
lazy bool, clashType model.ClashType, proxies ...model.Proxy,
) {
proxyTypes := model.GetSupportProxyTypes(clashType)
// 添加节点
for _, proxy := range proxies {
if !proxyTypes[proxy.Type] {
continue
}
sub.Proxies = append(sub.Proxies, proxy)
haveProxyGroup := false
countryName := GetContryName(proxy.Name)
for i := range sub.ProxyGroups {
group := &sub.ProxyGroups[i]
if group.Name == countryName {
group.Proxies = append(group.Proxies, proxy.Name)
group.Size++
haveProxyGroup = true
}
}
if !haveProxyGroup {
var newGroup model.ProxyGroup
if !autotest {
newGroup = model.ProxyGroup{
Name: countryName,
Type: "select",
Proxies: []string{proxy.Name},
IsCountryGrop: true,
Size: 1,
}
} else {
newGroup = model.ProxyGroup{
Name: countryName,
Type: "url-test",
Proxies: []string{proxy.Name},
IsCountryGrop: true,
Url: "http://www.gstatic.com/generate_204",
Interval: 300,
Tolerance: 50,
Lazy: lazy,
Size: 1,
}
}
sub.ProxyGroups = append(sub.ProxyGroups, newGroup)
}
}
}
func ParseProxy(proxies ...string) []model.Proxy {
var result []model.Proxy
for _, proxy := range proxies {
if proxy != "" {
var proxyItem model.Proxy
var err error
// 解析节点
if strings.HasPrefix(proxy, constant.ShadowsocksPrefix) {
proxyItem, err = parser.ParseShadowsocks(proxy)
}
if strings.HasPrefix(proxy, constant.TrojanPrefix) {
proxyItem, err = parser.ParseTrojan(proxy)
}
if strings.HasPrefix(proxy, constant.VMessPrefix) {
proxyItem, err = parser.ParseVmess(proxy)
}
if strings.HasPrefix(proxy, constant.VLESSPrefix) {
proxyItem, err = parser.ParseVless(proxy)
}
if strings.HasPrefix(proxy, constant.ShadowsocksRPrefix) {
proxyItem, err = parser.ParseShadowsocksR(proxy)
}
if strings.HasPrefix(proxy, constant.Hysteria2Prefix1) || strings.HasPrefix(proxy, constant.Hysteria2Prefix2) {
proxyItem, err = parser.ParseHysteria2(proxy)
}
if strings.HasPrefix(proxy, constant.HysteriaPrefix) {
proxyItem, err = parser.ParseHysteria(proxy)
}
if err == nil {
result = append(result, proxyItem)
} else {
logger.Logger.Debug(
"parse proxy failed", zap.String("proxy", proxy), zap.Error(err),
)
}
}
}
return result
}

13
common/random_string.go Normal file
View File

@ -0,0 +1,13 @@
package common
import "math/rand"
func RandomString(length int) string {
// 生成随机字符串
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var result []byte
for i := 0; i < length; i++ {
result = append(result, charset[rand.Intn(len(charset))])
}
return string(result)
}

50
common/rule.go Normal file
View File

@ -0,0 +1,50 @@
package common
import (
"fmt"
"strings"
"sub2clash/model"
)
func PrependRuleProvider(
sub *model.Subscription, providerName string, group string, provider model.RuleProvider,
) {
if sub.RuleProviders == nil {
sub.RuleProviders = make(map[string]model.RuleProvider)
}
sub.RuleProviders[providerName] = provider
PrependRules(
sub,
fmt.Sprintf("RULE-SET,%s,%s", providerName, group),
)
}
func AppenddRuleProvider(
sub *model.Subscription, providerName string, group string, provider model.RuleProvider,
) {
if sub.RuleProviders == nil {
sub.RuleProviders = make(map[string]model.RuleProvider)
}
sub.RuleProviders[providerName] = provider
AppendRules(sub, fmt.Sprintf("RULE-SET,%s,%s", providerName, group))
}
func PrependRules(sub *model.Subscription, rules ...string) {
if sub.Rules == nil {
sub.Rules = make([]string, 0)
}
sub.Rules = append(rules, sub.Rules...)
}
func AppendRules(sub *model.Subscription, rules ...string) {
if sub.Rules == nil {
sub.Rules = make([]string, 0)
}
matchRule := sub.Rules[len(sub.Rules)-1]
if strings.Contains(matchRule, "MATCH") {
sub.Rules = append(sub.Rules[:len(sub.Rules)-1], rules...)
sub.Rules = append(sub.Rules, matchRule)
return
}
sub.Rules = append(sub.Rules, rules...)
}

85
common/sub.go Normal file
View File

@ -0,0 +1,85 @@
package common
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"sub2clash/config"
"sync"
"time"
)
var subsDir = "subs"
var fileLock sync.RWMutex
func LoadSubscription(url string, refresh bool) ([]byte, error) {
if refresh {
return FetchSubscriptionFromAPI(url)
}
hash := sha256.Sum224([]byte(url))
fileName := filepath.Join(subsDir, hex.EncodeToString(hash[:]))
stat, err := os.Stat(fileName)
if err != nil {
if !os.IsNotExist(err) {
return nil, err
}
return FetchSubscriptionFromAPI(url)
}
lastGetTime := stat.ModTime().Unix() // 单位是秒
if lastGetTime+config.Default.CacheExpire > time.Now().Unix() {
file, err := os.Open(fileName)
if err != nil {
return nil, err
}
defer func(file *os.File) {
if file != nil {
_ = file.Close()
}
}(file)
fileLock.RLock()
defer fileLock.RUnlock()
subContent, err := io.ReadAll(file)
if err != nil {
return nil, err
}
return subContent, nil
}
return FetchSubscriptionFromAPI(url)
}
func FetchSubscriptionFromAPI(url string) ([]byte, error) {
hash := sha256.Sum224([]byte(url))
fileName := filepath.Join(subsDir, hex.EncodeToString(hash[:]))
resp, err := Get(url)
if err != nil {
return nil, err
}
defer func(Body io.ReadCloser) {
if Body != nil {
_ = Body.Close()
}
}(resp.Body)
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
file, err := os.Create(fileName)
if err != nil {
return nil, err
}
defer func(file *os.File) {
if file != nil {
_ = file.Close()
}
}(file)
fileLock.Lock()
defer fileLock.Unlock()
_, err = file.Write(data)
if err != nil {
return nil, fmt.Errorf("failed to write to sub.yaml: %w", err)
}
return data, nil
}

31
common/template.go Normal file
View File

@ -0,0 +1,31 @@
package common
import (
"errors"
"io"
"os"
"path/filepath"
)
// LoadTemplate 加载模板
// templates 模板文件名
func LoadTemplate(template string) ([]byte, error) {
tPath := filepath.Join("templates", template)
if _, err := os.Stat(tPath); err == nil {
file, err := os.Open(tPath)
if err != nil {
return nil, err
}
defer func(file *os.File) {
if file != nil {
_ = file.Close()
}
}(file)
result, err := io.ReadAll(file)
if err != nil {
return nil, err
}
return result, nil
}
return nil, errors.New("模板文件不存在")
}

View File

@ -0,0 +1,39 @@
package common
import (
"os"
"path/filepath"
"sub2clash/config"
)
func writeTemplate(path string, template string) error {
tPath := filepath.Join(
"templates", path,
)
if _, err := os.Stat(tPath); os.IsNotExist(err) {
file, err := os.Create(tPath)
if err != nil {
return err
}
defer func(file *os.File) {
if file != nil {
_ = file.Close()
}
}(file)
_, err = file.WriteString(template)
if err != nil {
return err
}
}
return nil
}
func WriteDefalutTemplate(templateMeta string, templateClash string) error {
if err := writeTemplate(config.Default.MetaTemplate, templateMeta); err != nil {
return err
}
if err := writeTemplate(config.Default.ClashTemplate, templateClash); err != nil {
return err
}
return nil
}