Skip to content

Commit 366f107

Browse files
Copilotlpcox
andauthored
Refactor shared DIFC decisions and logger level wrappers
Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/97151d73-796a-42a0-9fb4-c0de9f160c32 Co-authored-by: lpcox <[email protected]>
1 parent cf172de commit 366f107

8 files changed

Lines changed: 118 additions & 66 deletions

File tree

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package difc
2+
3+
// ShouldBypassCoarseDeny returns true when a coarse-grained deny should still
4+
// proceed to backend execution so Phase 5 can enforce per-item policy.
5+
func ShouldBypassCoarseDeny(operation OperationType) bool {
6+
return operation == OperationRead
7+
}
8+
9+
// ShouldCallLabelResponse returns true when guards should label response data
10+
// for possible fine-grained filtering.
11+
func ShouldCallLabelResponse(operation OperationType, enforcementMode EnforcementMode) bool {
12+
isPureWrite := operation == OperationWrite
13+
return !isPureWrite && (operation != OperationReadWrite || enforcementMode != EnforcementStrict)
14+
}
15+
16+
// ShouldBlockFilteredResponse returns true when filtered items should block the
17+
// whole response instead of returning a partially filtered result.
18+
func ShouldBlockFilteredResponse(enforcementMode EnforcementMode, filteredCount int) bool {
19+
return enforcementMode == EnforcementStrict && filteredCount > 0
20+
}
21+
22+
// ShouldAccumulateReadLabels returns true when read labels should be
23+
// accumulated back into the agent label set.
24+
func ShouldAccumulateReadLabels(operation OperationType, enforcementMode EnforcementMode) bool {
25+
return operation != OperationWrite && enforcementMode == EnforcementPropagate
26+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package difc
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestShouldBypassCoarseDeny(t *testing.T) {
10+
assert.True(t, ShouldBypassCoarseDeny(OperationRead))
11+
assert.False(t, ShouldBypassCoarseDeny(OperationWrite))
12+
assert.False(t, ShouldBypassCoarseDeny(OperationReadWrite))
13+
}
14+
15+
func TestShouldCallLabelResponse(t *testing.T) {
16+
assert.False(t, ShouldCallLabelResponse(OperationWrite, EnforcementStrict))
17+
assert.False(t, ShouldCallLabelResponse(OperationReadWrite, EnforcementStrict))
18+
assert.True(t, ShouldCallLabelResponse(OperationRead, EnforcementStrict))
19+
assert.True(t, ShouldCallLabelResponse(OperationReadWrite, EnforcementFilter))
20+
assert.True(t, ShouldCallLabelResponse(OperationReadWrite, EnforcementPropagate))
21+
}
22+
23+
func TestShouldBlockFilteredResponse(t *testing.T) {
24+
assert.True(t, ShouldBlockFilteredResponse(EnforcementStrict, 1))
25+
assert.False(t, ShouldBlockFilteredResponse(EnforcementStrict, 0))
26+
assert.False(t, ShouldBlockFilteredResponse(EnforcementFilter, 3))
27+
assert.False(t, ShouldBlockFilteredResponse(EnforcementPropagate, 2))
28+
}
29+
30+
func TestShouldAccumulateReadLabels(t *testing.T) {
31+
assert.True(t, ShouldAccumulateReadLabels(OperationRead, EnforcementPropagate))
32+
assert.True(t, ShouldAccumulateReadLabels(OperationReadWrite, EnforcementPropagate))
33+
assert.False(t, ShouldAccumulateReadLabels(OperationWrite, EnforcementPropagate))
34+
assert.False(t, ShouldAccumulateReadLabels(OperationRead, EnforcementStrict))
35+
assert.False(t, ShouldAccumulateReadLabels(OperationRead, EnforcementFilter))
36+
}

internal/logger/common.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,24 @@ import (
215215
// (markdown_logger.go) and logWithLevelAndServer (server_file_logger.go).
216216
// When adding a new LogLevel constant, add a corresponding entry here so
217217
// that all dispatch sites automatically support the new level.
218+
func makeLevelLogger(
219+
dispatch func(level LogLevel, category, format string, args ...interface{}),
220+
level LogLevel,
221+
) func(category, format string, args ...interface{}) {
222+
return func(category, format string, args ...interface{}) {
223+
dispatch(level, category, format, args...)
224+
}
225+
}
226+
227+
func makeServerLevelLogger(
228+
dispatch func(serverID string, level LogLevel, category, format string, args ...interface{}),
229+
level LogLevel,
230+
) func(serverID, category, format string, args ...interface{}) {
231+
return func(serverID, category, format string, args ...interface{}) {
232+
dispatch(serverID, level, category, format, args...)
233+
}
234+
}
235+
218236
var logFuncs = map[LogLevel]func(string, string, ...interface{}){
219237
LogLevelInfo: LogInfo,
220238
LogLevelWarn: LogWarn,

internal/logger/file_logger.go

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,16 @@ func logWithLevel(level LogLevel, category, format string, args ...interface{})
113113
})
114114
}
115115

116-
// LogInfo logs an informational message
117-
func LogInfo(category, format string, args ...interface{}) {
118-
logWithLevel(LogLevelInfo, category, format, args...)
119-
}
120-
121-
// LogWarn logs a warning message
122-
func LogWarn(category, format string, args ...interface{}) {
123-
logWithLevel(LogLevelWarn, category, format, args...)
124-
}
125-
126-
// LogError logs an error message
127-
func LogError(category, format string, args ...interface{}) {
128-
logWithLevel(LogLevelError, category, format, args...)
129-
}
130-
131-
// LogDebug logs a debug message
132-
func LogDebug(category, format string, args ...interface{}) {
133-
logWithLevel(LogLevelDebug, category, format, args...)
134-
}
116+
var (
117+
// LogInfo logs an informational message.
118+
LogInfo = makeLevelLogger(logWithLevel, LogLevelInfo)
119+
// LogWarn logs a warning message.
120+
LogWarn = makeLevelLogger(logWithLevel, LogLevelWarn)
121+
// LogError logs an error message.
122+
LogError = makeLevelLogger(logWithLevel, LogLevelError)
123+
// LogDebug logs a debug message.
124+
LogDebug = makeLevelLogger(logWithLevel, LogLevelDebug)
125+
)
135126

136127
// CloseGlobalLogger closes the global file logger
137128
func CloseGlobalLogger() error {

internal/logger/markdown_logger.go

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -180,25 +180,16 @@ func logWithMarkdown(level LogLevel, category, format string, args ...interface{
180180
})
181181
}
182182

183-
// LogInfoMd logs to both regular and markdown loggers
184-
func LogInfoMd(category, format string, args ...interface{}) {
185-
logWithMarkdown(LogLevelInfo, category, format, args...)
186-
}
187-
188-
// LogWarnMd logs to both regular and markdown loggers
189-
func LogWarnMd(category, format string, args ...interface{}) {
190-
logWithMarkdown(LogLevelWarn, category, format, args...)
191-
}
192-
193-
// LogErrorMd logs to both regular and markdown loggers
194-
func LogErrorMd(category, format string, args ...interface{}) {
195-
logWithMarkdown(LogLevelError, category, format, args...)
196-
}
197-
198-
// LogDebugMd logs to both regular and markdown loggers
199-
func LogDebugMd(category, format string, args ...interface{}) {
200-
logWithMarkdown(LogLevelDebug, category, format, args...)
201-
}
183+
var (
184+
// LogInfoMd logs to both regular and markdown loggers.
185+
LogInfoMd = makeLevelLogger(logWithMarkdown, LogLevelInfo)
186+
// LogWarnMd logs to both regular and markdown loggers.
187+
LogWarnMd = makeLevelLogger(logWithMarkdown, LogLevelWarn)
188+
// LogErrorMd logs to both regular and markdown loggers.
189+
LogErrorMd = makeLevelLogger(logWithMarkdown, LogLevelError)
190+
// LogDebugMd logs to both regular and markdown loggers.
191+
LogDebugMd = makeLevelLogger(logWithMarkdown, LogLevelDebug)
192+
)
202193

203194
// CloseMarkdownLogger closes the global markdown logger
204195
func CloseMarkdownLogger() error {

internal/logger/server_file_logger.go

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,25 +155,16 @@ func logWithLevelAndServer(serverID string, level LogLevel, category, format str
155155
}
156156
}
157157

158-
// LogInfoWithServer logs an informational message to the server-specific log file
159-
func LogInfoWithServer(serverID, category, format string, args ...interface{}) {
160-
logWithLevelAndServer(serverID, LogLevelInfo, category, format, args...)
161-
}
162-
163-
// LogWarnWithServer logs a warning message to the server-specific log file
164-
func LogWarnWithServer(serverID, category, format string, args ...interface{}) {
165-
logWithLevelAndServer(serverID, LogLevelWarn, category, format, args...)
166-
}
167-
168-
// LogErrorWithServer logs an error message to the server-specific log file
169-
func LogErrorWithServer(serverID, category, format string, args ...interface{}) {
170-
logWithLevelAndServer(serverID, LogLevelError, category, format, args...)
171-
}
172-
173-
// LogDebugWithServer logs a debug message to the server-specific log file
174-
func LogDebugWithServer(serverID, category, format string, args ...interface{}) {
175-
logWithLevelAndServer(serverID, LogLevelDebug, category, format, args...)
176-
}
158+
var (
159+
// LogInfoWithServer logs an informational message to the server-specific log file.
160+
LogInfoWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelInfo)
161+
// LogWarnWithServer logs a warning message to the server-specific log file.
162+
LogWarnWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelWarn)
163+
// LogErrorWithServer logs an error message to the server-specific log file.
164+
LogErrorWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelError)
165+
// LogDebugWithServer logs a debug message to the server-specific log file.
166+
LogDebugWithServer = makeServerLevelLogger(logWithLevelAndServer, LogLevelDebug)
167+
)
177168

178169
// CloseServerFileLogger closes the global server file logger
179170
func CloseServerFileLogger() error {

internal/proxy/handler.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
178178
evalResult := s.evaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation)
179179

180180
if !evalResult.IsAllowed() {
181-
if operation == difc.OperationRead {
181+
if difc.ShouldBypassCoarseDeny(operation) {
182182
// Read in filter mode: skip coarse block, proceed to fine-grained filtering
183183
logHandler.Printf("[DIFC] Phase 2: coarse check failed for read, proceeding to Phase 3")
184184
} else {
@@ -266,7 +266,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
266266
}
267267

268268
// Strict mode: block entire response if any item filtered
269-
if s.enforcementMode == difc.EnforcementStrict && filtered.GetFilteredCount() > 0 {
269+
if difc.ShouldBlockFilteredResponse(s.enforcementMode, filtered.GetFilteredCount()) {
270270
logHandler.Printf("[DIFC] STRICT: blocking response — %d filtered items", filtered.GetFilteredCount())
271271
writeDIFCForbidden(w, fmt.Sprintf("DIFC policy violation: %d of %d items not accessible",
272272
filtered.GetFilteredCount(), filtered.TotalCount))
@@ -318,7 +318,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
318318
}
319319

320320
// **Phase 6: Label accumulation (propagate mode)**
321-
if s.enforcementMode == difc.EnforcementPropagate && labeledData != nil {
321+
if labeledData != nil && difc.ShouldAccumulateReadLabels(operation, s.enforcementMode) {
322322
overall := labeledData.Overall()
323323
agentLabels.AccumulateFromRead(overall)
324324
logHandler.Printf("[DIFC] Phase 6: accumulated labels")

internal/server/unified.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
531531
// For read operations in any mode, we skip the coarse-grained block
532532
// and let the request proceed. Fine-grained filtering at Phase 5 will filter
533533
// individual items from the response based on their actual labels from LabelResponse().
534-
isReadOperation := (operation == difc.OperationRead)
534+
isReadOperation := difc.ShouldBypassCoarseDeny(operation)
535535
result := requestEvaluator.Evaluate(agentLabels.Secrecy, agentLabels.Integrity, resource, operation)
536536

537537
if !result.IsAllowed() {
@@ -603,8 +603,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
603603
// Per spec: LabelResponse() is only called for read operations in all modes,
604604
// and for read-write operations in filter/propagate modes.
605605
// For write operations and read-write in strict mode, skip LabelResponse().
606-
isPureWrite := (operation == difc.OperationWrite)
607-
shouldCallLabelResponse := !isPureWrite && (operation != difc.OperationReadWrite || enforcementMode != difc.EnforcementStrict)
606+
shouldCallLabelResponse := difc.ShouldCallLabelResponse(operation, enforcementMode)
608607

609608
var labeledData difc.LabeledData
610609
if shouldCallLabelResponse {
@@ -631,7 +630,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
631630
filtered.GetAccessibleCount(), filtered.TotalCount)
632631

633632
// **Strict mode: block entire response if ANY item is filtered**
634-
if enforcementMode == difc.EnforcementStrict && filtered.GetFilteredCount() > 0 {
633+
if difc.ShouldBlockFilteredResponse(enforcementMode, filtered.GetFilteredCount()) {
635634
logger.LogWarn("difc", "STRICT MODE: Blocking entire response - %d/%d items violate DIFC policy",
636635
filtered.GetFilteredCount(), filtered.TotalCount)
637636
blockErr := fmt.Errorf("DIFC policy violation: %d of %d items in response are not accessible to agent %s",
@@ -664,7 +663,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
664663
// **Phase 6: Accumulate labels from this operation (for reads in PROPAGATE mode only)**
665664
// Label accumulation should only happen when mode is EnforcementPropagate
666665
// Filter mode does NOT accumulate - it just filters what the agent can see
667-
if !isPureWrite && enforcementMode == difc.EnforcementPropagate {
666+
if difc.ShouldAccumulateReadLabels(operation, enforcementMode) {
668667
overall := labeledData.Overall()
669668
agentLabels.AccumulateFromRead(overall)
670669
logUnified.Printf("[DIFC] Agent %s accumulated labels (propagate mode) | Secrecy: %v | Integrity: %v",
@@ -675,7 +674,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
675674
finalResult = backendResult
676675

677676
// **Phase 6: Accumulate labels from resource (for reads in PROPAGATE mode only)**
678-
if !isPureWrite && enforcementMode == difc.EnforcementPropagate {
677+
if difc.ShouldAccumulateReadLabels(operation, enforcementMode) {
679678
agentLabels.AccumulateFromRead(resource)
680679
logUnified.Printf("[DIFC] Agent %s accumulated labels (propagate mode) | Secrecy: %v | Integrity: %v",
681680
agentID, agentLabels.GetSecrecyTags(), agentLabels.GetIntegrityTags())

0 commit comments

Comments
 (0)