Skip to content

Commit

Permalink
added client side authentication support while saving malicious and s…
Browse files Browse the repository at this point in the history
…mart events
  • Loading branch information
ag060 committed Nov 22, 2024
1 parent a474ce8 commit 0b5fc94
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 162 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.akto.auth.grpc;

import io.grpc.CallCredentials;
import io.grpc.Metadata;
import io.grpc.Status;

import java.util.concurrent.Executor;

import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;

public class AuthToken extends CallCredentials {

private final String token;
public static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER);

public AuthToken(String token) {
this.token = token;
}

@Override
public void applyRequestMetadata(
RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier) {
appExecutor.execute(
() -> {
try {
Metadata headers = new Metadata();
headers.put(AUTHORIZATION_METADATA_KEY, token);
applier.apply(headers);
} catch (Throwable e) {
applier.fail(Status.UNAUTHENTICATED.withCause(e));
}
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.*;

import com.akto.auth.grpc.AuthToken;
import com.akto.cache.RedisBackedCounterCache;
import com.akto.dao.context.Context;
import com.akto.dao.monitoring.FilterYamlTemplateDao;
Expand All @@ -22,11 +23,8 @@
import com.akto.kafka.Kafka;
import com.akto.log.LoggerMaker;
import com.akto.log.LoggerMaker.LogDb;
import com.akto.proto.threat_protection.consumer_service.v1.*;
import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc.ConsumerServiceBlockingStub;
import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc;
import com.akto.proto.threat_protection.consumer_service.v1.MaliciousEvent;
import com.akto.proto.threat_protection.consumer_service.v1.SaveSmartEventRequest;
import com.akto.proto.threat_protection.consumer_service.v1.SmartEvent;
import com.akto.rules.TestPlugin;
import com.akto.suspect_data.Message;
import com.akto.test_editor.execution.VariableResolver;
Expand All @@ -37,6 +35,7 @@
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.stub.StreamObserver;
import io.lettuce.core.RedisClient;

public class HttpCallFilter {
Expand All @@ -58,7 +57,7 @@ public class HttpCallFilter {

private final WindowBasedThresholdNotifier windowBasedThresholdNotifier;

private final ConsumerServiceBlockingStub consumerServiceBlockingStub;
private final ConsumerServiceGrpc.ConsumerServiceStub consumerServiceStub;

public HttpCallFilter(
RedisClient redisClient, int sync_threshold_count, int sync_threshold_time) {
Expand All @@ -76,7 +75,10 @@ public HttpCallFilter(
String target = "localhost:8980";
ManagedChannel channel =
Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()).build();
this.consumerServiceBlockingStub = ConsumerServiceGrpc.newBlockingStub(channel);
this.consumerServiceStub =
ConsumerServiceGrpc.newStub(channel)
.withCallCredentials(
new AuthToken(System.getenv("AKTO_THREAT_PROTECTION_BACKEND_TOKEN")));
}

public void filterFunction(List<HttpResponseParams> responseParams) {
Expand Down Expand Up @@ -110,6 +112,7 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
List<Rule> rules = new ArrayList<>();
rules.add(new Rule("Lfi Rule 1", new Condition(100, 10)));
AggregationRules aggRules = new AggregationRules();
aggRules.setRule(rules);

SourceIPKeyGenerator.instance
.generate(responseParam)
Expand All @@ -131,7 +134,13 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
.setTimestamp(responseParam.getTime())
.build();

maliciousSamples.add(new Message(responseParam.getAccountId(), maliciousEvent));
try {
String data = JsonFormat.printer().print(maliciousEvent);
maliciousSamples.add(new Message(responseParam.getAccountId(), data));
} catch (InvalidProtocolBufferException e) {
e.printStackTrace();
return;
}

for (Rule rule : aggRules.getRule()) {
WindowBasedThresholdNotifier.Result result =
Expand All @@ -146,8 +155,24 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
.setDetectedAt(responseParam.getTime())
.setRuleId(rule.getName())
.build();
this.consumerServiceBlockingStub.saveSmartEvent(
SaveSmartEventRequest.newBuilder().setEvent(smartEvent).build());
this.consumerServiceStub.saveSmartEvent(
SaveSmartEventRequest.newBuilder().setEvent(smartEvent).build(),
new StreamObserver<SaveSmartEventResponse>() {
@Override
public void onNext(SaveSmartEventResponse value) {
// Do nothing
}

@Override
public void onError(Throwable t) {
t.printStackTrace();
}

@Override
public void onCompleted() {
// Do nothing
}
});
}
}
});
Expand All @@ -160,12 +185,12 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
try {
maliciousSamples.forEach(
sample -> {
try {
String data = JsonFormat.printer().print(null);
kafka.send(data, KAFKA_MALICIOUS_TOPIC);
} catch (InvalidProtocolBufferException e) {
e.printStackTrace();
}
sample
.marshal()
.ifPresent(
data -> {
kafka.send(data, KAFKA_MALICIOUS_TOPIC);
});
});
} catch (Exception e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
package com.akto.suspect_data;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.akto.auth.grpc.AuthToken;
import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc;
import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc.ConsumerServiceStub;
import com.akto.proto.threat_protection.consumer_service.v1.MaliciousEvent;
import com.akto.proto.threat_protection.consumer_service.v1.SaveMaliciousEventRequest;
import com.akto.proto.threat_protection.consumer_service.v1.SaveMaliciousEventResponse;

import io.grpc.ChannelCredentials;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
Expand Down Expand Up @@ -45,11 +42,13 @@ private FlushMessagesTask() {
this.consumer = new KafkaConsumer<>(properties);

String target = System.getenv("AKTO_THREAT_PROTECTION_BACKEND_URL");
// TODO: Secure this connection
ManagedChannel channel =
Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()).build();

this.asyncStub = ConsumerServiceGrpc.newStub(channel);
this.asyncStub =
ConsumerServiceGrpc.newStub(channel)
.withCallCredentials(
new AuthToken(System.getenv("AKTO_THREAT_PROTECTION_BACKEND_TOKEN")));
}

public static FlushMessagesTask instance = new FlushMessagesTask();
Expand All @@ -66,7 +65,6 @@ public void run() {
processRecords(records);
} catch (Exception e) {
e.printStackTrace();
consumer.close();
}
}
}
Expand All @@ -78,9 +76,14 @@ public void processRecords(ConsumerRecords<String, String> records) {
for (ConsumerRecord<String, String> record : records) {
try {
MaliciousEvent.Builder builder = MaliciousEvent.newBuilder();
JsonFormat.parser().merge(record.value(), builder);
Message m = Message.unmarshal(record.value()).orElse(null);
if (m == null) {
continue;
}

JsonFormat.parser().merge(m.getData(), builder);
MaliciousEvent event = builder.build();
accWiseMessages.computeIfAbsent(record.key(), k -> new ArrayList<>()).add(event);
accWiseMessages.computeIfAbsent(m.getAccountId(), k -> new ArrayList<>()).add(event);
} catch (InvalidProtocolBufferException e) {
e.printStackTrace();
}
Expand All @@ -89,7 +92,6 @@ public void processRecords(ConsumerRecords<String, String> records) {
for (Map.Entry<String, List<MaliciousEvent>> entry : accWiseMessages.entrySet()) {
int accountId = Integer.parseInt(entry.getKey());
List<MaliciousEvent> events = entry.getValue();
Context.accountId.set(accountId);

this.asyncStub.saveMaliciousEvent(
SaveMaliciousEventRequest.newBuilder().addAllEvents(events).build(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,57 @@
package com.akto.suspect_data;

import com.akto.proto.threat_protection.consumer_service.v1.MaliciousEvent;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.Optional;

// Kafka Message Wrapper for suspect data
public class Message {
private String accountId;
private MaliciousEvent data;
private String accountId;
private String data;

public Message() {
}
private static ObjectMapper objectMapper = new ObjectMapper();

public Message(String accountId, MaliciousEvent data) {
this.accountId = accountId;
this.data = data;
}
public Message() {}

public String getAccountId() {
return accountId;
}
public Message(String accountId, String data) {
this.accountId = accountId;
this.data = data;
}

public void setAccountId(String accountId) {
this.accountId = accountId;
}
public String getAccountId() {
return accountId;
}

public MaliciousEvent getData() {
return data;
public void setAccountId(String accountId) {
this.accountId = accountId;
}

public String getData() {
return data;
}

public void setData(String data) {
this.data = data;
}

public Optional<String> marshal() {
try {
return Optional.ofNullable(objectMapper.writeValueAsString(this));
} catch (Exception e) {
e.printStackTrace();
}

public void setData(MaliciousEvent data) {
this.data = data;
return Optional.empty();
}

public static Optional<Message> unmarshal(String message) {
try {
return Optional.ofNullable(objectMapper.readValue(message, Message.class));
} catch (Exception e) {
e.printStackTrace();
}

return Optional.empty();
}
}
Loading

0 comments on commit 0b5fc94

Please sign in to comment.