1
0
mirror of https://github.com/bestnite/sub2clash.git synced 2026-04-26 12:51:52 +00:00

refactor: preserve template yaml structure

This commit is contained in:
2026-04-25 23:22:36 +10:00
parent 2d863752b1
commit 9d23b11751
12 changed files with 1395 additions and 472 deletions
+93
View File
@@ -0,0 +1,93 @@
package common
import (
P "github.com/bestnite/sub2clash/model/proxy"
"golang.org/x/text/collate"
"golang.org/x/text/language"
)
// proxyListDoc 只用于解析 YAML 订阅中的 proxies 字段。
// 方案 A/B 下我们不再关心订阅 YAML 里的其他 mihomo 配置项。
type proxyListDoc struct {
Proxy []P.Proxy `yaml:"proxies,omitempty"`
}
// generatedConfig 是运行期的最小叠加模型:
// 只保留本项目真正会读取、生成或修改的字段。
//
// 这里承载的是“本项目的业务叠加层”,而不是 mihomo 的完整配置模型:
// - Proxy: 解析出的节点,用于过滤、去重、分组等中间处理
// - ProxyGroup: 模板中需要参与占位符展开的组,以及本项目生成的国家组
// - Rule: 模板规则 + 用户追加规则,用于保持 MATCH 规则前插入的语义
type generatedConfig struct {
Proxy []P.Proxy `yaml:"proxies,omitempty"`
ProxyGroup []generatedGroup `yaml:"proxy-groups,omitempty"`
Rule []string `yaml:"rules,omitempty"`
}
// generatedGroup 表示本项目生成出来的代理组最小模型,
// 它不再镜像 mihomo 的完整 proxy-group 配置结构。
//
// 这里只保留“当前逻辑真正需要读写的字段”:
// - Name / Proxies:用于模板占位符展开与 patch
// - Type / Url / Interval / Tolerance / Lazy:用于输出自动测速国家组
// - Size / IsCountry:仅作为运行期辅助信息,不参与 YAML 输出
type generatedGroup struct {
Type string `yaml:"type,omitempty"`
Name string `yaml:"name,omitempty"`
Proxies []string `yaml:"proxies,omitempty"`
Url string `yaml:"url,omitempty"`
Interval int `yaml:"interval,omitempty"`
Tolerance int `yaml:"tolerance,omitempty"`
Lazy bool `yaml:"lazy"`
Size int `yaml:"-"`
IsCountry bool `yaml:"-"`
}
// generatedRulePatch 表示本项目追加/覆盖的 rule-provider 最小模型。
// 它仅用于把用户请求转换成对 templateDoc 的字段级 patch。
type generatedRulePatch struct {
Type string `yaml:"type,omitempty"`
Behavior string `yaml:"behavior,omitempty"`
Url string `yaml:"url,omitempty"`
Path string `yaml:"path,omitempty"`
Interval int `yaml:"interval,omitempty"`
Format string `yaml:"format,omitempty"`
}
type generatedGroupsSortByName []generatedGroup
type generatedGroupsSortBySize []generatedGroup
func (p generatedGroupsSortByName) Len() int {
return len(p)
}
func (p generatedGroupsSortBySize) Len() int {
return len(p)
}
func (p generatedGroupsSortByName) Less(i, j int) bool {
tags := []language.Tag{
language.English,
language.Chinese,
}
matcher := language.NewMatcher(tags)
bestMatch, _, _ := matcher.Match(language.Make("zh"))
c := collate.New(bestMatch)
return c.CompareString(p[i].Name, p[j].Name) < 0
}
func (p generatedGroupsSortBySize) Less(i, j int) bool {
if p[i].Size == p[j].Size {
return p[i].Name < p[j].Name
}
return p[i].Size < p[j].Size
}
func (p generatedGroupsSortByName) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
func (p generatedGroupsSortBySize) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
+18 -18
View File
@@ -47,7 +47,7 @@ func GetContryName(countryKey string) string {
}
func AddProxy(
sub *model.Subscription, autotest bool,
sub *generatedConfig, autotest bool,
lazy bool, clashType model.ClashType, proxies ...proxy.Proxy,
) {
proxyTypes := model.GetSupportProxyTypes(clashType)
@@ -68,26 +68,26 @@ func AddProxy(
}
}
if !haveProxyGroup {
var newGroup model.ProxyGroup
var newGroup generatedGroup
if !autotest {
newGroup = model.ProxyGroup{
Name: countryName,
Type: "select",
Proxies: []string{proxy.Name},
IsCountryGrop: true,
Size: 1,
newGroup = generatedGroup{
Name: countryName,
Type: "select",
Proxies: []string{proxy.Name},
IsCountry: true,
Size: 1,
}
} else {
newGroup = model.ProxyGroup{
Name: countryName,
Type: "url-test",
Proxies: []string{proxy.Name},
IsCountryGrop: true,
Url: "http://www.gstatic.com/generate_204",
Interval: 300,
Tolerance: 50,
Lazy: lazy,
Size: 1,
newGroup = generatedGroup{
Name: countryName,
Type: "url-test",
Proxies: []string{proxy.Name},
IsCountry: true,
Url: "http://www.gstatic.com/generate_204",
Interval: 300,
Tolerance: 50,
Lazy: lazy,
Size: 1,
}
}
sub.ProxyGroup = append(sub.ProxyGroup, newGroup)
+11 -14
View File
@@ -3,17 +3,11 @@ package common
import (
"fmt"
"strings"
"github.com/bestnite/sub2clash/model"
)
func PrependRuleProvider(
sub *model.Subscription, providerName string, group string, provider model.RuleProvider,
sub *generatedConfig, providerName string, group string,
) {
if sub.RuleProvider == nil {
sub.RuleProvider = make(map[string]model.RuleProvider)
}
sub.RuleProvider[providerName] = provider
PrependRules(
sub,
fmt.Sprintf("RULE-SET,%s,%s", providerName, group),
@@ -21,26 +15,29 @@ func PrependRuleProvider(
}
func AppenddRuleProvider(
sub *model.Subscription, providerName string, group string, provider model.RuleProvider,
sub *generatedConfig, providerName string, group string,
) {
if sub.RuleProvider == nil {
sub.RuleProvider = make(map[string]model.RuleProvider)
}
sub.RuleProvider[providerName] = provider
AppendRules(sub, fmt.Sprintf("RULE-SET,%s,%s", providerName, group))
}
func PrependRules(sub *model.Subscription, rules ...string) {
// PrependRules 用于在规则头部插入新规则。
// 这通常对应用户显式要求 prepend 的场景。
func PrependRules(sub *generatedConfig, rules ...string) {
if sub.Rule == nil {
sub.Rule = make([]string, 0)
}
sub.Rule = append(rules, sub.Rule...)
}
func AppendRules(sub *model.Subscription, rules ...string) {
// AppendRules 在规则尾部追加,但如果尾部已有 MATCH,则保持 MATCH 仍然是最后一条。
func AppendRules(sub *generatedConfig, rules ...string) {
if sub.Rule == nil {
sub.Rule = make([]string, 0)
}
if len(sub.Rule) == 0 {
sub.Rule = append(sub.Rule, rules...)
return
}
matchRule := sub.Rule[len(sub.Rule)-1]
if strings.Contains(matchRule, "MATCH") {
sub.Rule = append(sub.Rule[:len(sub.Rule)-1], rules...)
+407 -120
View File
@@ -93,11 +93,85 @@ func FetchSubscriptionFromAPI(url string, userAgent string, retryTimes int) ([]b
return data, nil
}
// BuildSub 是当前配置转换链路的核心入口。
//
// 当前设计分为三层:
// 1. templateDoc:模板 YAML 的完整语法树,也是最终输出真源
// 2. generatedConfig:本项目运行期最小叠加层,只保存参与业务计算的字段
// 3. proxy.Proxy:节点解析后的 typed 模型,用于过滤、去重、重命名和输出
//
// 这个函数的目标不是“重建一整份 mihomo 配置”,而是:
// - 保留模板中绝大部分原始字段
// - 只对 proxies / proxy-groups / rules / rule-providers 做定点 patch
func BuildSub(clashType model.ClashType, query model.ConvertConfig, template string, cacheExpire int64, retryTimes int) (
*model.Subscription, error,
*BuiltSub, error,
) {
var temp = &model.Subscription{}
var sub = &model.Subscription{}
templateDoc, templateBytes, err := loadTemplateDocument(query, template, cacheExpire, retryTimes)
if err != nil {
return nil, err
}
temp, err := extractTemplateOverlay(templateDoc)
if err != nil {
logger.Logger.Debug("extract template overlay failed", zap.Error(err))
return nil, NewTemplateParseError(templateBytes, err)
}
proxyList, err := collectQueryProxies(query, cacheExpire, retryTimes)
if err != nil {
return nil, err
}
proxyList, err = normalizeProxyList(query, proxyList)
if err != nil {
return nil, err
}
// t 仅承载“由节点生成出来的新内容”,例如国家组。
// 模板里原有的组、规则等则保存在 temp 中。
generated, err := buildGeneratedConfig(clashType, query, proxyList)
if err != nil {
return nil, err
}
MergeSubAndTemplate(temp, generated, query.IgnoreCountryGrooup)
applyRulePatches(temp, query)
addedRuleProviders := buildRuleProviderPatches(query)
if err := mergeTemplateProxies(templateDoc, generated.Proxy); err != nil {
return nil, NewError(ErrConfigInvalid, "failed to update template path: proxies", err)
}
if temp.ProxyGroup == nil {
temp.ProxyGroup = make([]generatedGroup, 0)
}
if err := mergeTemplateProxyGroups(templateDoc, temp.ProxyGroup); err != nil {
return nil, NewError(ErrConfigInvalid, "failed to update template path: proxy-groups", err)
}
rulesChanged := len(query.Rules) != 0 || len(query.RuleProviders) != 0
if rulesChanged {
if temp.Rule == nil {
temp.Rule = make([]string, 0)
}
if err := SetYAMLPath(templateDoc, "rules", temp.Rule); err != nil {
return nil, NewError(ErrConfigInvalid, "failed to update template path: rules", err)
}
}
if len(query.RuleProviders) != 0 {
if err := mergeTemplateRuleProviders(templateDoc, addedRuleProviders); err != nil {
return nil, NewError(ErrConfigInvalid, "failed to update template path: rule-providers", err)
}
}
return &BuiltSub{root: templateDoc}, nil
}
// loadTemplateDocument 负责统一加载模板来源,并返回:
// 1. 解析后的 YAML 语法树
// 2. 原始模板字节,用于错误报告
func loadTemplateDocument(query model.ConvertConfig, template string, cacheExpire int64, retryTimes int) (*yaml.Node, []byte, error) {
var err error
var templateBytes []byte
@@ -110,79 +184,38 @@ func BuildSub(clashType model.ClashType, query model.ConvertConfig, template str
logger.Logger.Debug(
"load template failed", zap.String("template", template), zap.Error(err),
)
return nil, NewTemplateLoadError(template, err)
return nil, nil, NewTemplateLoadError(template, err)
}
} else {
unescape, err := url.QueryUnescape(template)
if err != nil {
return nil, NewTemplateLoadError(template, err)
return nil, nil, NewTemplateLoadError(template, err)
}
templateBytes, err = LoadTemplate(unescape)
if err != nil {
logger.Logger.Debug(
"load template failed", zap.String("template", template), zap.Error(err),
)
return nil, NewTemplateLoadError(unescape, err)
return nil, nil, NewTemplateLoadError(unescape, err)
}
}
err = yaml.Unmarshal(templateBytes, &temp)
templateDoc, err := ParseYAMLDocument(templateBytes)
if err != nil {
logger.Logger.Debug("parse template failed", zap.Error(err))
return nil, NewTemplateParseError(templateBytes, err)
logger.Logger.Debug("parse template yaml node failed", zap.Error(err))
return nil, templateBytes, NewTemplateParseError(templateBytes, err)
}
var proxyList []P.Proxy
return templateDoc, templateBytes, nil
}
// collectQueryProxies 汇总来自订阅链接和直接传入代理链接的所有节点。
func collectQueryProxies(query model.ConvertConfig, cacheExpire int64, retryTimes int) ([]P.Proxy, error) {
proxyList := make([]P.Proxy, 0)
for i := range query.Subs {
data, err := LoadSubscription(query.Subs[i], query.Refresh, query.UserAgent, cacheExpire, retryTimes)
newProxies, err := loadSubscriptionProxies(query, query.Subs[i], cacheExpire, retryTimes)
if err != nil {
logger.Logger.Debug(
"load subscription failed", zap.String("url", query.Subs[i]), zap.Error(err),
)
return nil, NewSubscriptionLoadError(query.Subs[i], err)
}
subName := ""
if strings.Contains(query.Subs[i], "#") {
subName = query.Subs[i][strings.LastIndex(query.Subs[i], "#")+1:]
}
err = yaml.Unmarshal(data, &sub)
var newProxies []P.Proxy
if err != nil {
reg, err := regexp.Compile("(" + strings.Join(parser.GetAllPrefixes(), "|") + ")://")
if err != nil {
logger.Logger.Debug("compile regex failed", zap.Error(err))
return nil, NewRegexInvalidError("prefix", err)
}
if reg.Match(data) {
p, err := parser.ParseProxies(parser.ParseConfig{UseUDP: query.UseUDP}, strings.Split(string(data), "\n")...)
if err != nil {
return nil, err
}
newProxies = p
} else {
base64, err := utils.DecodeBase64(string(data), false)
if err != nil {
logger.Logger.Debug(
"parse subscription failed", zap.String("url", query.Subs[i]),
zap.String("data", string(data)),
zap.Error(err),
)
return nil, NewSubscriptionParseError(data, err)
}
p, err := parser.ParseProxies(parser.ParseConfig{UseUDP: query.UseUDP}, strings.Split(base64, "\n")...)
if err != nil {
return nil, err
}
newProxies = p
}
} else {
newProxies = sub.Proxy
}
if subName != "" {
for i := range newProxies {
newProxies[i].SubName = subName
}
return nil, err
}
proxyList = append(proxyList, newProxies...)
}
@@ -195,13 +228,103 @@ func BuildSub(clashType model.ClashType, query model.ConvertConfig, template str
proxyList = append(proxyList, p...)
}
return proxyList, nil
}
// loadSubscriptionProxies 负责加载单条订阅并应用订阅名作为节点前缀。
func loadSubscriptionProxies(query model.ConvertConfig, subscriptionURL string, cacheExpire int64, retryTimes int) ([]P.Proxy, error) {
data, err := LoadSubscription(subscriptionURL, query.Refresh, query.UserAgent, cacheExpire, retryTimes)
if err != nil {
logger.Logger.Debug(
"load subscription failed", zap.String("url", subscriptionURL), zap.Error(err),
)
return nil, NewSubscriptionLoadError(subscriptionURL, err)
}
subName := ""
if strings.Contains(subscriptionURL, "#") {
subName = subscriptionURL[strings.LastIndex(subscriptionURL, "#")+1:]
}
newProxies, err := parseSubscriptionProxies(data, query.UseUDP, subscriptionURL)
if err != nil {
return nil, err
}
if subName != "" {
for i := range newProxies {
newProxies[i].SubName = subName
}
}
return newProxies, nil
}
// parseSubscriptionProxies 按“Clash YAML -> URI 列表 -> Base64 文本”的顺序容错解析节点。
func parseSubscriptionProxies(data []byte, useUDP bool, subscriptionURL string) ([]P.Proxy, error) {
sub := &proxyListDoc{}
if err := yaml.Unmarshal(data, sub); err == nil {
return sub.Proxy, nil
}
reg, err := regexp.Compile("(" + strings.Join(parser.GetAllPrefixes(), "|") + ")://")
if err != nil {
logger.Logger.Debug("compile regex failed", zap.Error(err))
return nil, NewRegexInvalidError("prefix", err)
}
if reg.Match(data) {
return parser.ParseProxies(parser.ParseConfig{UseUDP: useUDP}, strings.Split(string(data), "\n")...)
}
base64, err := utils.DecodeBase64(string(data), false)
if err != nil {
logger.Logger.Debug(
"parse subscription failed", zap.String("url", subscriptionURL),
zap.String("data", string(data)),
zap.Error(err),
)
return nil, NewSubscriptionParseError(data, err)
}
return parser.ParseProxies(parser.ParseConfig{UseUDP: useUDP}, strings.Split(base64, "\n")...)
}
// normalizeProxyList 汇总所有节点标准化步骤,确保后续分组和 patch 使用的是稳定结果。
func normalizeProxyList(query model.ConvertConfig, proxyList []P.Proxy) ([]P.Proxy, error) {
applySubscriptionPrefixes(proxyList)
var err error
proxyList, err = dedupeProxies(proxyList)
if err != nil {
return nil, err
}
proxyList, err = removeProxiesByPattern(proxyList, query.Remove)
if err != nil {
return nil, err
}
proxyList, err = replaceProxyNames(proxyList, query.Replace)
if err != nil {
return nil, err
}
ensureUniqueProxyNames(proxyList)
trimProxyNames(proxyList)
return proxyList, nil
}
func applySubscriptionPrefixes(proxyList []P.Proxy) {
for i := range proxyList {
if proxyList[i].SubName != "" {
proxyList[i].Name = strings.TrimSpace(proxyList[i].SubName) + " " + strings.TrimSpace(proxyList[i].Name)
}
}
}
// 去重
// dedupeProxies 通过 YAML 序列化结果判定两个节点是否完全相同。
func dedupeProxies(proxyList []P.Proxy) ([]P.Proxy, error) {
proxies := make(map[string]*P.Proxy)
newProxies := make([]P.Proxy, 0, len(proxyList))
for i := range proxyList {
@@ -216,45 +339,52 @@ func BuildSub(clashType model.ClashType, query model.ConvertConfig, template str
newProxies = append(newProxies, proxyList[i])
}
}
proxyList = newProxies
return newProxies, nil
}
// 移除
if strings.TrimSpace(query.Remove) != "" {
newProxyList := make([]P.Proxy, 0, len(proxyList))
func removeProxiesByPattern(proxyList []P.Proxy, pattern string) ([]P.Proxy, error) {
if strings.TrimSpace(pattern) == "" {
return proxyList, nil
}
removeReg, err := regexp.Compile(pattern)
if err != nil {
logger.Logger.Debug("remove regexp compile failed", zap.Error(err))
return nil, NewRegexInvalidError("remove", err)
}
newProxyList := make([]P.Proxy, 0, len(proxyList))
for i := range proxyList {
if removeReg.MatchString(proxyList[i].Name) {
continue
}
newProxyList = append(newProxyList, proxyList[i])
}
return newProxyList, nil
}
func replaceProxyNames(proxyList []P.Proxy, replacements map[string]string) ([]P.Proxy, error) {
if len(replacements) == 0 {
return proxyList, nil
}
for pattern, replacement := range replacements {
replaceReg, err := regexp.Compile(pattern)
if err != nil {
logger.Logger.Debug("replace regexp compile failed", zap.Error(err))
return nil, NewRegexInvalidError("replace", err)
}
for i := range proxyList {
removeReg, err := regexp.Compile(query.Remove)
if err != nil {
logger.Logger.Debug("remove regexp compile failed", zap.Error(err))
return nil, NewRegexInvalidError("remove", err)
}
if removeReg.MatchString(proxyList[i].Name) {
continue
}
newProxyList = append(newProxyList, proxyList[i])
}
proxyList = newProxyList
}
// 替换
if len(query.Replace) != 0 {
for k, v := range query.Replace {
replaceReg, err := regexp.Compile(k)
if err != nil {
logger.Logger.Debug("replace regexp compile failed", zap.Error(err))
return nil, NewRegexInvalidError("replace", err)
}
for i := range proxyList {
if replaceReg.MatchString(proxyList[i].Name) {
proxyList[i].Name = replaceReg.ReplaceAllString(
proxyList[i].Name, v,
)
}
if replaceReg.MatchString(proxyList[i].Name) {
proxyList[i].Name = replaceReg.ReplaceAllString(proxyList[i].Name, replacement)
}
}
}
// 重命名有相同名称的节点
return proxyList, nil
}
func ensureUniqueProxyNames(proxyList []P.Proxy) {
names := make(map[string]int)
for i := range proxyList {
if _, exist := names[proxyList[i].Name]; exist {
@@ -264,30 +394,39 @@ func BuildSub(clashType model.ClashType, query model.ConvertConfig, template str
names[proxyList[i].Name] = 0
}
}
}
func trimProxyNames(proxyList []P.Proxy) {
for i := range proxyList {
proxyList[i].Name = strings.TrimSpace(proxyList[i].Name)
}
}
var t = &model.Subscription{}
AddProxy(t, query.AutoTest, query.Lazy, clashType, proxyList...)
// buildGeneratedConfig 只生成“新增内容”,例如国家组和最终可输出的节点集合。
func buildGeneratedConfig(clashType model.ClashType, query model.ConvertConfig, proxyList []P.Proxy) (*generatedConfig, error) {
generated := &generatedConfig{}
AddProxy(generated, query.AutoTest, query.Lazy, clashType, proxyList...)
sortGeneratedGroups(generated, query.Sort)
return generated, nil
}
// 排序
switch query.Sort {
func sortGeneratedGroups(generated *generatedConfig, sortMode string) {
switch sortMode {
case "sizeasc":
sort.Sort(model.ProxyGroupsSortBySize(t.ProxyGroup))
sort.Sort(generatedGroupsSortBySize(generated.ProxyGroup))
case "sizedesc":
sort.Sort(sort.Reverse(model.ProxyGroupsSortBySize(t.ProxyGroup)))
sort.Sort(sort.Reverse(generatedGroupsSortBySize(generated.ProxyGroup)))
case "nameasc":
sort.Sort(model.ProxyGroupsSortByName(t.ProxyGroup))
sort.Sort(generatedGroupsSortByName(generated.ProxyGroup))
case "namedesc":
sort.Sort(sort.Reverse(model.ProxyGroupsSortByName(t.ProxyGroup)))
sort.Sort(sort.Reverse(generatedGroupsSortByName(generated.ProxyGroup)))
default:
sort.Sort(model.ProxyGroupsSortByName(t.ProxyGroup))
sort.Sort(generatedGroupsSortByName(generated.ProxyGroup))
}
}
MergeSubAndTemplate(temp, t, query.IgnoreCountryGrooup)
// applyRulePatches 只修改运行期 overlay 中的 rules 切片,不直接写 YAML。
func applyRulePatches(temp *generatedConfig, query model.ConvertConfig) {
for _, v := range query.Rules {
if v.Prepend {
PrependRules(temp, v.Rule)
@@ -295,28 +434,176 @@ func BuildSub(clashType model.ClashType, query model.ConvertConfig, template str
AppendRules(temp, v.Rule)
}
}
for _, v := range query.RuleProviders {
if v.Prepend {
PrependRuleProvider(temp, v.Name, v.Group)
} else {
AppenddRuleProvider(temp, v.Name, v.Group)
}
}
}
// buildRuleProviderPatches 把 API 请求中的 rule-provider 参数转换成 YAML patch payload。
func buildRuleProviderPatches(query model.ConvertConfig) map[string]generatedRulePatch {
if len(query.RuleProviders) == 0 {
return nil
}
patches := make(map[string]generatedRulePatch, len(query.RuleProviders))
for _, v := range query.RuleProviders {
hash := sha256.Sum224([]byte(v.Url))
name := hex.EncodeToString(hash[:])
provider := model.RuleProvider{
patches[v.Name] = generatedRulePatch{
Type: "http",
Behavior: v.Behavior,
Url: v.Url,
Path: "./" + name + ".yaml",
Interval: 3600,
}
if v.Prepend {
PrependRuleProvider(
temp, v.Name, v.Group, provider,
)
} else {
AppenddRuleProvider(
temp, v.Name, v.Group, provider,
)
}
return patches
}
// extractTemplateOverlay 只从模板 YAML 树中提取本项目真正会参与计算的局部字段。
// 这让模板读取完全基于 yaml.Node,而不再依赖任何整份配置的 typed unmarshal。
func extractTemplateOverlay(templateDoc *yaml.Node) (*generatedConfig, error) {
overlay := &generatedConfig{}
if err := decodeOptionalYAMLPath(templateDoc, "proxy-groups", &overlay.ProxyGroup); err != nil {
return nil, err
}
if err := decodeOptionalYAMLPath(templateDoc, "rules", &overlay.Rule); err != nil {
return nil, err
}
return overlay, nil
}
// decodeOptionalYAMLPath 在路径存在且非 null 时才执行 Decode
// 路径不存在时保持目标值为零值。
func decodeOptionalYAMLPath(doc *yaml.Node, path string, target any) error {
node, err := GetYAMLPath(doc, path)
if err != nil {
return err
}
if node == nil || isNullYAMLNode(node) {
return nil
}
if err := node.Decode(target); err != nil {
return fmt.Errorf("decode template path %q failed: %w", path, err)
}
return nil
}
// mergeTemplateProxies 只负责把本项目生成出的代理追加到模板现有 proxies 后面。
// 模板中已有代理节点原样保留,不做 struct round-trip。
func mergeTemplateProxies(templateDoc *yaml.Node, generated []P.Proxy) error {
if len(generated) == 0 && !HasYAMLPath(templateDoc, "proxies") {
return nil
}
proxiesNode, err := EnsureYAMLSequencePath(templateDoc, "proxies")
if err != nil {
return err
}
for _, proxy := range generated {
if err := AppendYAMLSequenceValue(proxiesNode, proxy); err != nil {
return err
}
}
return temp, nil
return nil
}
// mergeTemplateProxyGroups 负责两类更新:
// 1. 对模板中同名组,仅覆盖 proxies 字段,保留其他字段
// 2. 追加本项目新生成的国家组
func mergeTemplateProxyGroups(templateDoc *yaml.Node, groups []generatedGroup) error {
if len(groups) == 0 && !HasYAMLPath(templateDoc, "proxy-groups") {
return nil
}
groupNodes, err := EnsureYAMLSequencePath(templateDoc, "proxy-groups")
if err != nil {
return err
}
for _, group := range groups {
if group.IsCountry {
if existing := FindYAMLSequenceMappingByStringField(groupNodes, "name", group.Name); existing != nil {
continue
}
if err := AppendYAMLSequenceValue(groupNodes, group); err != nil {
return err
}
continue
}
existing := FindYAMLSequenceMappingByStringField(groupNodes, "name", group.Name)
if existing == nil {
if err := AppendYAMLSequenceValue(groupNodes, group); err != nil {
return err
}
continue
}
if findMappingValue(existing, "proxies") == nil {
continue
}
if err := SetYAMLMappingField(existing, "proxies", group.Proxies); err != nil {
return err
}
}
return nil
}
// mergeTemplateRuleProviders 以字段级 patch 的方式更新/插入 rule-provider
// 以避免覆盖模板中已有 provider 的未知字段。
func mergeTemplateRuleProviders(templateDoc *yaml.Node, providers map[string]generatedRulePatch) error {
if len(providers) == 0 && !HasYAMLPath(templateDoc, "rule-providers") {
return nil
}
providerNodes, err := EnsureYAMLMappingPath(templateDoc, "rule-providers")
if err != nil {
return err
}
for name, provider := range providers {
existing := findMappingValue(providerNodes, name)
if existing != nil && existing.Kind == yaml.MappingNode {
if err := SetYAMLMappingField(existing, "type", provider.Type); err != nil {
return err
}
if err := SetYAMLMappingField(existing, "behavior", provider.Behavior); err != nil {
return err
}
if err := SetYAMLMappingField(existing, "url", provider.Url); err != nil {
return err
}
if err := SetYAMLMappingField(existing, "path", provider.Path); err != nil {
return err
}
if err := SetYAMLMappingField(existing, "interval", provider.Interval); err != nil {
return err
}
if provider.Format != "" {
if err := SetYAMLMappingField(existing, "format", provider.Format); err != nil {
return err
}
}
continue
}
if err := SetYAMLMappingField(providerNodes, name, provider); err != nil {
return err
}
}
return nil
}
func FetchSubscriptionUserInfo(url string, userAgent string, retryTimes int) (string, error) {
@@ -336,10 +623,12 @@ func FetchSubscriptionUserInfo(url string, userAgent string, retryTimes int) (st
return "", NewNetworkResponseError("subscription-userinfo header not found", nil)
}
func MergeSubAndTemplate(temp *model.Subscription, sub *model.Subscription, igcg bool) {
// MergeSubAndTemplate 把“模板侧需要参与计算的最小叠加层”和“本项目生成结果”合并。
// 它只处理本项目关心的运行期结构,不负责最终 YAML 输出。
func MergeSubAndTemplate(temp *generatedConfig, sub *generatedConfig, igcg bool) {
var countryGroupNames []string
for _, proxyGroup := range sub.ProxyGroup {
if proxyGroup.IsCountryGrop {
if proxyGroup.IsCountry {
countryGroupNames = append(
countryGroupNames, proxyGroup.Name,
)
@@ -350,16 +639,14 @@ func MergeSubAndTemplate(temp *model.Subscription, sub *model.Subscription, igcg
proxyNames = append(proxyNames, proxy.Name)
}
temp.Proxy = append(temp.Proxy, sub.Proxy...)
for i := range temp.ProxyGroup {
if temp.ProxyGroup[i].IsCountryGrop {
if temp.ProxyGroup[i].IsCountry {
continue
}
newProxies := make([]string, 0)
countryGroupMap := make(map[string]model.ProxyGroup)
countryGroupMap := make(map[string]generatedGroup)
for _, v := range sub.ProxyGroup {
if v.IsCountryGrop {
if v.IsCountry {
countryGroupMap[v.Name] = v
}
}
+478
View File
@@ -0,0 +1,478 @@
package common
import (
"os"
"path/filepath"
"testing"
"github.com/bestnite/sub2clash/model"
"gopkg.in/yaml.v3"
)
func withRepoRoot(t *testing.T) {
t.Helper()
originalWD, err := os.Getwd()
if err != nil {
t.Fatalf("get working directory: %v", err)
}
repoRoot := filepath.Dir(originalWD)
if err := os.Chdir(repoRoot); err != nil {
t.Fatalf("change working directory: %v", err)
}
t.Cleanup(func() {
_ = os.Chdir(originalWD)
})
}
func TestBuildSubPreservesUnmodeledTemplateSections(t *testing.T) {
withRepoRoot(t)
templateName := "test_scheme_a_template.yaml"
templatePath := filepath.Join(templatesDir, templateName)
templateContent := `mixed-port: 7890
dns:
enable: true
future-field: true
new-section:
enabled: true
proxies:
proxy-groups:
- name: 节点选择
type: select
proxies:
- <countries>
- DIRECT
rules:
- MATCH,节点选择
`
if err := os.WriteFile(templatePath, []byte(templateContent), 0o644); err != nil {
t.Fatalf("write template: %v", err)
}
t.Cleanup(func() {
_ = os.Remove(templatePath)
})
result, err := BuildSub(model.Clash, model.ConvertConfig{
ClashType: model.Clash,
Proxies: []string{
"ss://YWVzLTI1Ni1nY206cGFzc3dvcmQ=@127.0.0.1:8080#Test Node",
},
}, templateName, 0, 0)
if err != nil {
t.Fatalf("build subscription: %v", err)
}
output, err := yaml.Marshal(result)
if err != nil {
t.Fatalf("marshal result: %v", err)
}
var doc map[string]any
if err := yaml.Unmarshal(output, &doc); err != nil {
t.Fatalf("unmarshal output: %v", err)
}
dns, ok := doc["dns"].(map[string]any)
if !ok {
t.Fatalf("dns section missing: %s", output)
}
if dns["future-field"] != true {
t.Fatalf("dns future-field not preserved: %#v", dns)
}
newSection, ok := doc["new-section"].(map[string]any)
if !ok {
t.Fatalf("new-section missing: %s", output)
}
if newSection["enabled"] != true {
t.Fatalf("new-section not preserved: %#v", newSection)
}
proxies, ok := doc["proxies"].([]any)
if !ok || len(proxies) != 1 {
t.Fatalf("expected generated proxies in output: %#v", doc["proxies"])
}
rules, ok := doc["rules"].([]any)
if !ok || len(rules) != 1 || rules[0] != "MATCH,节点选择" {
t.Fatalf("rules should stay untouched without rule patches: %#v", doc["rules"])
}
}
func TestBuildSubPreservesTemplateProxyAndGroupFields(t *testing.T) {
withRepoRoot(t)
templateName := "test_scheme_a_group_template.yaml"
templatePath := filepath.Join(templatesDir, templateName)
templateContent := `proxies:
- name: Template Proxy
type: ss
server: 1.1.1.1
port: 443
cipher: aes-256-gcm
password: password
future-proxy-field: keep
proxy-groups:
- name: 节点选择
type: select
future-group-field: keep
proxies:
- <countries>
- DIRECT
rules:
- MATCH,节点选择
`
if err := os.WriteFile(templatePath, []byte(templateContent), 0o644); err != nil {
t.Fatalf("write template: %v", err)
}
t.Cleanup(func() {
_ = os.Remove(templatePath)
})
result, err := BuildSub(model.Clash, model.ConvertConfig{
ClashType: model.Clash,
Proxies: []string{
"ss://YWVzLTI1Ni1nY206cGFzc3dvcmQ=@127.0.0.1:8080#Test Node",
},
}, templateName, 0, 0)
if err != nil {
t.Fatalf("build subscription: %v", err)
}
output, err := yaml.Marshal(result)
if err != nil {
t.Fatalf("marshal result: %v", err)
}
var doc map[string]any
if err := yaml.Unmarshal(output, &doc); err != nil {
t.Fatalf("unmarshal output: %v", err)
}
proxies, ok := doc["proxies"].([]any)
if !ok || len(proxies) != 2 {
t.Fatalf("expected two proxies in output: %#v", doc["proxies"])
}
firstProxy, ok := proxies[0].(map[string]any)
if !ok {
t.Fatalf("template proxy should remain a mapping: %#v", proxies[0])
}
if firstProxy["future-proxy-field"] != "keep" {
t.Fatalf("template proxy field not preserved: %#v", firstProxy)
}
groups, ok := doc["proxy-groups"].([]any)
if !ok || len(groups) == 0 {
t.Fatalf("expected proxy groups in output: %#v", doc["proxy-groups"])
}
firstGroup, ok := groups[0].(map[string]any)
if !ok {
t.Fatalf("template group should remain a mapping: %#v", groups[0])
}
if firstGroup["future-group-field"] != "keep" {
t.Fatalf("template proxy-group field not preserved: %#v", firstGroup)
}
groupProxies, ok := firstGroup["proxies"].([]any)
if !ok || len(groupProxies) == 0 {
t.Fatalf("template proxy-group proxies missing: %#v", firstGroup["proxies"])
}
for _, value := range groupProxies {
if value == "<countries>" {
t.Fatalf("placeholder should be resolved in template proxy-group: %#v", groupProxies)
}
}
}
func TestBuildSubAddsRulesForRuleProviderWhenTemplateHasNoRules(t *testing.T) {
withRepoRoot(t)
templateName := "test_scheme_a_rule_provider_template.yaml"
templatePath := filepath.Join(templatesDir, templateName)
templateContent := `proxy-groups:
- name: 节点选择
type: select
proxies:
- DIRECT
`
if err := os.WriteFile(templatePath, []byte(templateContent), 0o644); err != nil {
t.Fatalf("write template: %v", err)
}
t.Cleanup(func() {
_ = os.Remove(templatePath)
})
result, err := BuildSub(model.Clash, model.ConvertConfig{
ClashType: model.Clash,
Proxies: []string{
"ss://YWVzLTI1Ni1nY206cGFzc3dvcmQ=@127.0.0.1:8080#Test Node",
},
RuleProviders: []model.RuleProviderStruct{{
Name: "test-provider",
Group: "节点选择",
Behavior: "domain",
Url: "https://example.com/rules.yaml",
}},
}, templateName, 0, 0)
if err != nil {
t.Fatalf("build subscription: %v", err)
}
output, err := yaml.Marshal(result)
if err != nil {
t.Fatalf("marshal result: %v", err)
}
var doc map[string]any
if err := yaml.Unmarshal(output, &doc); err != nil {
t.Fatalf("unmarshal output: %v", err)
}
ruleProviders, ok := doc["rule-providers"].(map[string]any)
if !ok {
t.Fatalf("rule-providers missing: %#v", doc["rule-providers"])
}
if _, ok := ruleProviders["test-provider"]; !ok {
t.Fatalf("test-provider missing: %#v", ruleProviders)
}
rules, ok := doc["rules"].([]any)
if !ok || len(rules) != 1 || rules[0] != "RULE-SET,test-provider,节点选择" {
t.Fatalf("expected generated rule for provider: %#v", doc["rules"])
}
}
func TestBuildSubDoesNotInjectProxiesFieldIntoUseBasedGroup(t *testing.T) {
withRepoRoot(t)
templateName := "test_scheme_a_use_group_template.yaml"
templatePath := filepath.Join(templatesDir, templateName)
templateContent := `proxy-groups:
- name: 节点选择
type: select
use:
- provider-a
rules:
- MATCH,节点选择
`
if err := os.WriteFile(templatePath, []byte(templateContent), 0o644); err != nil {
t.Fatalf("write template: %v", err)
}
t.Cleanup(func() {
_ = os.Remove(templatePath)
})
result, err := BuildSub(model.Clash, model.ConvertConfig{
ClashType: model.Clash,
Proxies: []string{
"ss://YWVzLTI1Ni1nY206cGFzc3dvcmQ=@127.0.0.1:8080#Test Node",
},
}, templateName, 0, 0)
if err != nil {
t.Fatalf("build subscription: %v", err)
}
output, err := yaml.Marshal(result)
if err != nil {
t.Fatalf("marshal result: %v", err)
}
var doc map[string]any
if err := yaml.Unmarshal(output, &doc); err != nil {
t.Fatalf("unmarshal output: %v", err)
}
groups := doc["proxy-groups"].([]any)
firstGroup := groups[0].(map[string]any)
if _, exists := firstGroup["proxies"]; exists {
t.Fatalf("use-based group should not gain proxies field: %#v", firstGroup)
}
if _, exists := firstGroup["use"]; !exists {
t.Fatalf("use-based group should preserve use field: %#v", firstGroup)
}
}
func TestBuildSubPreservesUnknownFieldsOnExistingRuleProvider(t *testing.T) {
withRepoRoot(t)
templateName := "test_scheme_a_existing_provider_template.yaml"
templatePath := filepath.Join(templatesDir, templateName)
templateContent := `proxy-groups:
- name: 节点选择
type: select
proxies:
- DIRECT
rule-providers:
test-provider:
type: http
behavior: classical
url: https://old.example.com/rules.yaml
path: ./old.yaml
interval: 10
future-provider-field: keep
`
if err := os.WriteFile(templatePath, []byte(templateContent), 0o644); err != nil {
t.Fatalf("write template: %v", err)
}
t.Cleanup(func() {
_ = os.Remove(templatePath)
})
result, err := BuildSub(model.Clash, model.ConvertConfig{
ClashType: model.Clash,
Proxies: []string{
"ss://YWVzLTI1Ni1nY206cGFzc3dvcmQ=@127.0.0.1:8080#Test Node",
},
RuleProviders: []model.RuleProviderStruct{{
Name: "test-provider",
Group: "节点选择",
Behavior: "domain",
Url: "https://example.com/rules.yaml",
}},
}, templateName, 0, 0)
if err != nil {
t.Fatalf("build subscription: %v", err)
}
output, err := yaml.Marshal(result)
if err != nil {
t.Fatalf("marshal result: %v", err)
}
var doc map[string]any
if err := yaml.Unmarshal(output, &doc); err != nil {
t.Fatalf("unmarshal output: %v", err)
}
ruleProviders := doc["rule-providers"].(map[string]any)
provider := ruleProviders["test-provider"].(map[string]any)
if provider["future-provider-field"] != "keep" {
t.Fatalf("existing provider field not preserved: %#v", provider)
}
if provider["behavior"] != "domain" {
t.Fatalf("provider behavior not updated: %#v", provider)
}
if provider["url"] != "https://example.com/rules.yaml" {
t.Fatalf("provider url not updated: %#v", provider)
}
}
func TestBuildSubSkipsDuplicateCountryGroupNames(t *testing.T) {
withRepoRoot(t)
templateName := "test_scheme_a_country_group_template.yaml"
templatePath := filepath.Join(templatesDir, templateName)
templateContent := `proxy-groups:
- name: 其他地区
type: select
proxies:
- DIRECT
rules:
- MATCH,其他地区
`
if err := os.WriteFile(templatePath, []byte(templateContent), 0o644); err != nil {
t.Fatalf("write template: %v", err)
}
t.Cleanup(func() {
_ = os.Remove(templatePath)
})
result, err := BuildSub(model.Clash, model.ConvertConfig{
ClashType: model.Clash,
Proxies: []string{
"ss://YWVzLTI1Ni1nY206cGFzc3dvcmQ=@127.0.0.1:8080#UnknownCountryNode",
},
}, templateName, 0, 0)
if err != nil {
t.Fatalf("build subscription: %v", err)
}
output, err := yaml.Marshal(result)
if err != nil {
t.Fatalf("marshal result: %v", err)
}
var doc map[string]any
if err := yaml.Unmarshal(output, &doc); err != nil {
t.Fatalf("unmarshal output: %v", err)
}
groups := doc["proxy-groups"].([]any)
count := 0
for _, item := range groups {
group := item.(map[string]any)
if group["name"] == "其他地区" {
count++
}
}
if count != 1 {
t.Fatalf("expected duplicate country group names to be skipped, got %d entries: %#v", count, groups)
}
}
func TestBuiltSubMarshalNodeListYAMLUsesFinalYAMLTree(t *testing.T) {
withRepoRoot(t)
templateName := "test_scheme_a_nodelist_template.yaml"
templatePath := filepath.Join(templatesDir, templateName)
templateContent := `proxies:
- name: Template Proxy
type: ss
server: 1.1.1.1
port: 443
cipher: aes-256-gcm
password: password
future-proxy-field: keep
proxy-groups:
- name: 节点选择
type: select
proxies:
- DIRECT
`
if err := os.WriteFile(templatePath, []byte(templateContent), 0o644); err != nil {
t.Fatalf("write template: %v", err)
}
t.Cleanup(func() {
_ = os.Remove(templatePath)
})
result, err := BuildSub(model.Clash, model.ConvertConfig{
ClashType: model.Clash,
Proxies: []string{
"ss://YWVzLTI1Ni1nY206cGFzc3dvcmQ=@127.0.0.1:8080#Generated Node",
},
}, templateName, 0, 0)
if err != nil {
t.Fatalf("build subscription: %v", err)
}
output, err := result.MarshalNodeListYAML()
if err != nil {
t.Fatalf("marshal node list: %v", err)
}
var doc map[string]any
if err := yaml.Unmarshal(output, &doc); err != nil {
t.Fatalf("unmarshal output: %v", err)
}
proxies, ok := doc["proxies"].([]any)
if !ok || len(proxies) != 2 {
t.Fatalf("expected node list to include template and generated proxies: %#v", doc["proxies"])
}
firstProxy, ok := proxies[0].(map[string]any)
if !ok {
t.Fatalf("template proxy should remain a mapping: %#v", proxies[0])
}
if firstProxy["future-proxy-field"] != "keep" {
t.Fatalf("node list should be built from final yaml tree: %#v", firstProxy)
}
}
+383
View File
@@ -0,0 +1,383 @@
package common
import (
"fmt"
"strings"
"gopkg.in/yaml.v3"
)
// BuiltSub 保存最终输出所需的完整 YAML 树。
//
// 这里刻意不再保存整份 typed 配置副本:
// - root 是整个转换流程的最终产物
// - 所有常规输出都直接从 root 序列化
// - nodeList 模式也从 root 中提取 proxies,而不是依赖额外状态
type BuiltSub struct {
root *yaml.Node
}
// MarshalYAML 让 BuiltSub 在输出时直接复用 patch 后的 YAML 树,
// 从而避免再次经过 struct round-trip 丢失未知字段。
func (b *BuiltSub) MarshalYAML() (any, error) {
if b == nil || b.root == nil {
return nil, nil
}
if b.root.Kind == yaml.DocumentNode {
if len(b.root.Content) == 0 {
return nil, nil
}
return b.root.Content[0], nil
}
return b.root, nil
}
// MarshalNodeListYAML 从最终 YAML 树中提取 proxies 节点,构造 nodeList 模式输出。
// 这样 nodeList 也直接复用最终 root,而不是依赖额外的 typed struct 副本。
func (b *BuiltSub) MarshalNodeListYAML() ([]byte, error) {
if b == nil || b.root == nil {
return yaml.Marshal(&yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"})
}
proxiesNode, err := GetYAMLPath(b.root, "proxies")
if err != nil {
return nil, err
}
root := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
if proxiesNode != nil && !isNullYAMLNode(proxiesNode) {
setMappingValue(root, "proxies", cloneYAMLNode(proxiesNode))
}
return yaml.Marshal(root)
}
// ParseYAMLDocument 把原始 YAML 解析成 DocumentNode
// 并确保根内容最终是一个可写入的 mapping 节点。
func ParseYAMLDocument(data []byte) (*yaml.Node, error) {
var doc yaml.Node
if err := yaml.Unmarshal(data, &doc); err != nil {
return nil, err
}
if _, err := rootMappingNode(&doc); err != nil {
return nil, err
}
return &doc, nil
}
// HasYAMLPath 判断某个点路径是否存在。
// 这里仅关心“是否找到节点”,不关心节点具体类型。
func HasYAMLPath(doc *yaml.Node, path string) bool {
current, err := GetYAMLPath(doc, path)
return err == nil && current != nil
}
// GetYAMLPath 按 a.b.c 这种点路径向下查找节点。
// 当前实现只支持 mapping 之间的逐层下钻,不处理数组索引路径。
func GetYAMLPath(doc *yaml.Node, path string) (*yaml.Node, error) {
segments := splitYAMLPath(path)
if len(segments) == 0 {
return nil, fmt.Errorf("yaml path is empty")
}
current, err := rootMappingNode(doc)
if err != nil {
return nil, err
}
for _, segment := range segments {
next := findMappingValue(current, segment)
if next == nil {
return nil, nil
}
current = next
}
return current, nil
}
// SetYAMLPath 按点路径写入一个值;不存在的中间层会自动补成 mapping。
// 例如 a.b.c=1 会在缺失时依次创建 a 和 b 两层对象节点。
func SetYAMLPath(doc *yaml.Node, path string, value any) error {
segments := splitYAMLPath(path)
if len(segments) == 0 {
return fmt.Errorf("yaml path is empty")
}
current, err := rootMappingNode(doc)
if err != nil {
return err
}
for idx, segment := range segments[:len(segments)-1] {
next := findMappingValue(current, segment)
if next == nil {
next = &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
setMappingValue(current, segment, next)
}
if next.Kind != yaml.MappingNode {
return fmt.Errorf("yaml path %q segment %q is not a mapping", path, strings.Join(segments[:idx+1], "."))
}
current = next
}
encoded, err := encodeYAMLNode(value)
if err != nil {
return err
}
setMappingValue(current, segments[len(segments)-1], encoded)
return nil
}
// EnsureYAMLSequencePath 确保某个路径最终是 sequence(YAML 数组)节点。
// 不存在时会自动创建,已存在但类型不匹配时返回错误。
func EnsureYAMLSequencePath(doc *yaml.Node, path string) (*yaml.Node, error) {
return ensureYAMLPathKind(doc, path, yaml.SequenceNode, "!!seq")
}
// EnsureYAMLMappingPath 确保某个路径最终是 mapping(YAML 对象)节点。
func EnsureYAMLMappingPath(doc *yaml.Node, path string) (*yaml.Node, error) {
return ensureYAMLPathKind(doc, path, yaml.MappingNode, "!!map")
}
// SetYAMLMappingField 在一个 mapping 节点里设置单个字段。
// 它等价于“在当前对象上写 key: value”。
func SetYAMLMappingField(node *yaml.Node, key string, value any) error {
if node == nil || node.Kind != yaml.MappingNode {
return fmt.Errorf("yaml node is not a mapping")
}
encoded, err := encodeYAMLNode(value)
if err != nil {
return err
}
setMappingValue(node, key, encoded)
return nil
}
// AppendYAMLSequenceValue 向 sequence 节点末尾追加一个元素。
func AppendYAMLSequenceValue(node *yaml.Node, value any) error {
if node == nil || node.Kind != yaml.SequenceNode {
return fmt.Errorf("yaml node is not a sequence")
}
encoded, err := encodeYAMLNode(value)
if err != nil {
return err
}
node.Content = append(node.Content, encoded)
return nil
}
// FindYAMLSequenceMappingByStringField 在 YAML 数组中查找一个对象元素,
// 要求该对象存在指定字段且字段值等于目标字符串。
//
// 例如在 proxy-groups 里按 name 查找:
// - name: 节点选择
// type: select
func FindYAMLSequenceMappingByStringField(node *yaml.Node, field string, value string) *yaml.Node {
if node == nil || node.Kind != yaml.SequenceNode {
return nil
}
for _, item := range node.Content {
if item == nil || item.Kind != yaml.MappingNode {
continue
}
fieldNode := findMappingValue(item, field)
if fieldNode == nil || fieldNode.Kind != yaml.ScalarNode {
continue
}
if fieldNode.Value == value {
return item
}
}
return nil
}
// splitYAMLPath 把 a.b.c 这种点路径拆成 [a b c]。
// 空片段会被忽略,避免出现连续点号时产生无意义路径段。
func splitYAMLPath(path string) []string {
parts := strings.Split(path, ".")
segments := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
segments = append(segments, part)
}
return segments
}
// ensureYAMLPathKind 是 EnsureYAMLSequencePath / EnsureYAMLMappingPath 的底层实现。
// 它会:
// 1. 逐层确保中间节点存在且都是 mapping
// 2. 确保最后一个节点存在,且类型符合预期
func ensureYAMLPathKind(doc *yaml.Node, path string, kind yaml.Kind, tag string) (*yaml.Node, error) {
segments := splitYAMLPath(path)
if len(segments) == 0 {
return nil, fmt.Errorf("yaml path is empty")
}
current, err := rootMappingNode(doc)
if err != nil {
return nil, err
}
// 跳过最后一个元素在后面处理
for idx, segment := range segments[:len(segments)-1] {
next := findMappingValue(current, segment)
if next == nil || isNullYAMLNode(next) {
next = &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
setMappingValue(current, segment, next)
}
if next.Kind != yaml.MappingNode {
return nil, fmt.Errorf("yaml path %q segment %q is not a mapping", path, strings.Join(segments[:idx+1], "."))
}
current = next
}
lastSegment := segments[len(segments)-1]
node := findMappingValue(current, lastSegment)
if node == nil || isNullYAMLNode(node) {
node = &yaml.Node{Kind: kind, Tag: tag}
setMappingValue(current, lastSegment, node)
}
if node.Kind != kind {
return nil, fmt.Errorf("yaml path %q is not a %s", path, yamlKindName(kind))
}
return node, nil
}
// rootMappingNode 统一把“文档根”整理成一个可操作的 mapping 节点。
//
// yaml.v3 通常把整份 YAML 包在 DocumentNode 下,真正的内容位于 Content[0]。
// 当前项目的 patch 逻辑都假定最外层是 key-value 结构,因此这里会:
// 1. 处理空文档
// 2. 取出 DocumentNode 的实际根内容
// 3. 确保该根内容是 mapping
func rootMappingNode(doc *yaml.Node) (*yaml.Node, error) {
if doc == nil {
return nil, fmt.Errorf("yaml document is nil")
}
root := doc
if doc.Kind == 0 {
doc.Kind = yaml.DocumentNode
}
if doc.Kind == yaml.DocumentNode {
if len(doc.Content) == 0 {
doc.Content = append(doc.Content, &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"})
}
root = doc.Content[0]
}
if root.Kind == 0 {
root.Kind = yaml.MappingNode
root.Tag = "!!map"
}
if root.Kind != yaml.MappingNode {
return nil, fmt.Errorf("yaml root must be a mapping node")
}
return root, nil
}
// isNullYAMLNode 判断一个节点是否为空/未初始化/null。
// 这让我们在“路径不存在”和“路径存在但值为 null”时都能按缺失处理。
func isNullYAMLNode(node *yaml.Node) bool {
if node == nil {
return true
}
if node.Kind == 0 {
return true
}
return node.Kind == yaml.ScalarNode && node.Tag == "!!null"
}
// yamlKindName 仅用于生成更可读的错误信息。
func yamlKindName(kind yaml.Kind) string {
switch kind {
case yaml.MappingNode:
return "mapping"
case yaml.SequenceNode:
return "sequence"
case yaml.ScalarNode:
return "scalar"
case yaml.DocumentNode:
return "document"
default:
return "node"
}
}
// findMappingValue 在 mapping 节点中按 key 查找对应的 value 节点。
//
// 需要注意:yaml.v3 的 MappingNode.Content 不是 map,而是交替存储:
// [key1, value1, key2, value2, ...]
// 所以这里每次 idx += 2,依次跳过一个完整的 key-value 对。
func findMappingValue(node *yaml.Node, key string) *yaml.Node {
if node == nil || node.Kind != yaml.MappingNode {
return nil
}
for idx := 0; idx+1 < len(node.Content); idx += 2 {
if node.Content[idx].Value == key {
return node.Content[idx+1]
}
}
return nil
}
// setMappingValue 在 mapping 节点中设置 key 对应的 value。
// 如果 key 已存在,就原位替换;否则在末尾追加一组新的 key-value。
func setMappingValue(node *yaml.Node, key string, value *yaml.Node) {
for idx := 0; idx+1 < len(node.Content); idx += 2 {
if node.Content[idx].Value == key {
node.Content[idx+1] = value
return
}
}
node.Content = append(node.Content,
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key},
value,
)
}
// encodeYAMLNode 把普通 Go 值编码成 *yaml.Node,方便统一塞回 YAML 树。
// 如果 Encode 产生的是 DocumentNode,这里会自动取出它的实际内容节点。
func encodeYAMLNode(value any) (*yaml.Node, error) {
var node yaml.Node
if err := node.Encode(value); err != nil {
return nil, err
}
if node.Kind == yaml.DocumentNode {
if len(node.Content) == 0 {
return &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}, nil
}
return node.Content[0], nil
}
return &node, nil
}
// cloneYAMLNode 深拷贝一个节点树,避免把同一个子树同时挂到多个输出根下。
func cloneYAMLNode(node *yaml.Node) *yaml.Node {
if node == nil {
return nil
}
clone := *node
if len(node.Content) != 0 {
clone.Content = make([]*yaml.Node, len(node.Content))
for i := range node.Content {
clone.Content[i] = cloneYAMLNode(node.Content[i])
}
}
return &clone
}