Skip to content

Commit 4e07a6b

Browse files
feat: get access key from request and read valid access keys from file
issue: #218
1 parent aff116c commit 4e07a6b

14 files changed

+395
-48
lines changed

lapis2/src/main/kotlin/org/genspectrum/lapis/LapisSpringConfig.kt

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
package org.genspectrum.lapis
22

3-
import com.fasterxml.jackson.databind.ObjectMapper
4-
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
53
import com.fasterxml.jackson.module.kotlin.readValue
6-
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
74
import mu.KotlinLogging
8-
import org.genspectrum.lapis.auth.DataOpennessAuthorizationFilter
5+
import org.genspectrum.lapis.auth.DataOpennessAuthorizationFilterFactory
96
import org.genspectrum.lapis.config.DatabaseConfig
107
import org.genspectrum.lapis.config.SequenceFilterFields
118
import org.genspectrum.lapis.logging.RequestContext
129
import org.genspectrum.lapis.logging.RequestContextLogger
1310
import org.genspectrum.lapis.logging.StatisticsLogObjectMapper
1411
import org.genspectrum.lapis.util.TimeFactory
12+
import org.genspectrum.lapis.util.YamlObjectMapper
1513
import org.springframework.beans.factory.annotation.Value
1614
import org.springframework.context.annotation.Bean
1715
import org.springframework.context.annotation.Configuration
@@ -24,8 +22,11 @@ class LapisSpringConfig {
2422
fun openAPI(sequenceFilterFields: SequenceFilterFields) = buildOpenApiSchema(sequenceFilterFields)
2523

2624
@Bean
27-
fun databaseConfig(@Value("\${lapis.databaseConfig.path}") configPath: String): DatabaseConfig {
28-
return ObjectMapper(YAMLFactory()).registerKotlinModule().readValue(File(configPath))
25+
fun databaseConfig(
26+
@Value("\${lapis.databaseConfig.path}") configPath: String,
27+
yamlObjectMapper: YamlObjectMapper,
28+
): DatabaseConfig {
29+
return yamlObjectMapper.objectMapper.readValue(File(configPath))
2930
}
3031

3132
@Bean
@@ -55,6 +56,7 @@ class LapisSpringConfig {
5556
)
5657

5758
@Bean
58-
fun dataOpennessAuthorizationFilter(databaseConfig: DatabaseConfig, objectMapper: ObjectMapper) =
59-
DataOpennessAuthorizationFilter.createFromConfig(databaseConfig, objectMapper)
59+
fun dataOpennessAuthorizationFilter(
60+
dataOpennessAuthorizationFilterFactory: DataOpennessAuthorizationFilterFactory,
61+
) = dataOpennessAuthorizationFilterFactory.create()
6062
}
Lines changed: 85 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,52 @@
11
package org.genspectrum.lapis.auth
22

33
import com.fasterxml.jackson.databind.ObjectMapper
4+
import com.fasterxml.jackson.module.kotlin.readValue
45
import jakarta.servlet.FilterChain
56
import jakarta.servlet.http.HttpServletRequest
67
import jakarta.servlet.http.HttpServletResponse
8+
import mu.KotlinLogging
9+
import org.genspectrum.lapis.config.AccessKeys
10+
import org.genspectrum.lapis.config.AccessKeysReader
711
import org.genspectrum.lapis.config.DatabaseConfig
812
import org.genspectrum.lapis.config.OpennessLevel
913
import org.genspectrum.lapis.controller.LapisHttpErrorResponse
14+
import org.genspectrum.lapis.util.CachedBodyHttpServletRequest
1015
import org.springframework.http.HttpStatus
1116
import org.springframework.http.MediaType
17+
import org.springframework.stereotype.Component
1218
import org.springframework.web.filter.OncePerRequestFilter
1319

14-
abstract class DataOpennessAuthorizationFilter(val objectMapper: ObjectMapper) : OncePerRequestFilter() {
20+
const val ACCESS_KEY_PROPERTY = "accessKey"
21+
22+
private val log = KotlinLogging.logger {}
23+
24+
@Component
25+
class DataOpennessAuthorizationFilterFactory(
26+
private val databaseConfig: DatabaseConfig,
27+
private val objectMapper: ObjectMapper,
28+
private val accessKeysReader: AccessKeysReader,
29+
) {
30+
fun create() = when (databaseConfig.schema.opennessLevel) {
31+
OpennessLevel.OPEN -> AlwaysAuthorizedAuthorizationFilter(objectMapper)
32+
OpennessLevel.GISAID -> ProtectedGisaidDataAuthorizationFilter(
33+
objectMapper,
34+
accessKeysReader.read(),
35+
databaseConfig.schema.metadata.filter { it.unique }.map { it.name },
36+
)
37+
}
38+
}
39+
40+
abstract class DataOpennessAuthorizationFilter(protected val objectMapper: ObjectMapper) : OncePerRequestFilter() {
1541
override fun doFilterInternal(
1642
request: HttpServletRequest,
1743
response: HttpServletResponse,
1844
filterChain: FilterChain,
1945
) {
20-
when (val result = isAuthorizedForEndpoint(request)) {
21-
AuthorizationResult.Success -> filterChain.doFilter(request, response)
46+
val reReadableRequest = CachedBodyHttpServletRequest(request)
47+
48+
when (val result = isAuthorizedForEndpoint(reReadableRequest)) {
49+
AuthorizationResult.Success -> filterChain.doFilter(reReadableRequest, response)
2250
is AuthorizationResult.Failure -> {
2351
response.status = HttpStatus.FORBIDDEN.value()
2452
response.contentType = MediaType.APPLICATION_JSON_VALUE
@@ -34,15 +62,7 @@ abstract class DataOpennessAuthorizationFilter(val objectMapper: ObjectMapper) :
3462
}
3563
}
3664

37-
abstract fun isAuthorizedForEndpoint(request: HttpServletRequest): AuthorizationResult
38-
39-
companion object {
40-
fun createFromConfig(databaseConfig: DatabaseConfig, objectMapper: ObjectMapper) =
41-
when (databaseConfig.schema.opennessLevel) {
42-
OpennessLevel.OPEN -> NoOpAuthorizationFilter(objectMapper)
43-
OpennessLevel.GISAID -> ProtectedGisaidDataAuthorizationFilter(objectMapper)
44-
}
45-
}
65+
abstract fun isAuthorizedForEndpoint(request: CachedBodyHttpServletRequest): AuthorizationResult
4666
}
4767

4868
sealed interface AuthorizationResult {
@@ -52,24 +72,64 @@ sealed interface AuthorizationResult {
5272
fun failure(message: String): AuthorizationResult = Failure(message)
5373
}
5474

55-
fun isSuccessful(): Boolean
56-
57-
object Success : AuthorizationResult {
58-
override fun isSuccessful() = true
59-
}
75+
object Success : AuthorizationResult
6076

61-
class Failure(val message: String) : AuthorizationResult {
62-
override fun isSuccessful() = false
63-
}
77+
class Failure(val message: String) : AuthorizationResult
6478
}
6579

66-
private class NoOpAuthorizationFilter(objectMapper: ObjectMapper) : DataOpennessAuthorizationFilter(objectMapper) {
67-
override fun isAuthorizedForEndpoint(request: HttpServletRequest) = AuthorizationResult.success()
80+
private class AlwaysAuthorizedAuthorizationFilter(objectMapper: ObjectMapper) :
81+
DataOpennessAuthorizationFilter(objectMapper) {
82+
83+
override fun isAuthorizedForEndpoint(request: CachedBodyHttpServletRequest) = AuthorizationResult.success()
6884
}
6985

70-
private class ProtectedGisaidDataAuthorizationFilter(objectMapper: ObjectMapper) :
86+
private class ProtectedGisaidDataAuthorizationFilter(
87+
objectMapper: ObjectMapper,
88+
private val accessKeys: AccessKeys,
89+
private val fieldsThatServeNonAggregatedData: List<String>,
90+
) :
7191
DataOpennessAuthorizationFilter(objectMapper) {
7292

73-
override fun isAuthorizedForEndpoint(request: HttpServletRequest) =
74-
AuthorizationResult.failure("An access key is required to access this endpoint.")
93+
companion object {
94+
private val ENDPOINTS_THAT_SERVE_AGGREGATED_DATA = listOf("/aggregated", "/nucleotideMutations")
95+
}
96+
97+
override fun isAuthorizedForEndpoint(request: CachedBodyHttpServletRequest): AuthorizationResult {
98+
val requestFields = getRequestFields(request)
99+
100+
val accessKey = requestFields[ACCESS_KEY_PROPERTY]
101+
?: return AuthorizationResult.failure("An access key is required to access this endpoint.")
102+
103+
if (accessKeys.fullAccessKey == accessKey) {
104+
return AuthorizationResult.success()
105+
}
106+
107+
val endpointServesAggregatedData = ENDPOINTS_THAT_SERVE_AGGREGATED_DATA.contains(request.requestURI) &&
108+
fieldsThatServeNonAggregatedData.intersect(requestFields.keys).isEmpty()
109+
110+
if (endpointServesAggregatedData && accessKeys.aggregatedDataAccessKey == accessKey) {
111+
return AuthorizationResult.success()
112+
}
113+
114+
return AuthorizationResult.failure("You are not authorized to access this endpoint.")
115+
}
116+
117+
private fun getRequestFields(request: CachedBodyHttpServletRequest): Map<String, String> {
118+
if (request.parameterNames.hasMoreElements()) {
119+
return request.parameterMap.mapValues { (_, value) -> value.joinToString() }
120+
}
121+
122+
if (request.contentLength == 0) {
123+
log.warn { "Could not read access key from body, because content length is 0." }
124+
return emptyMap()
125+
}
126+
127+
return try {
128+
objectMapper.readValue(request.inputStream)
129+
} catch (exception: Exception) {
130+
log.error { "Failed to read access key from request body: ${exception.message}" }
131+
log.debug { exception.stackTraceToString() }
132+
emptyMap()
133+
}
134+
}
75135
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package org.genspectrum.lapis.config
2+
3+
import com.fasterxml.jackson.module.kotlin.readValue
4+
import org.genspectrum.lapis.util.YamlObjectMapper
5+
import org.springframework.beans.factory.annotation.Value
6+
import org.springframework.stereotype.Component
7+
import java.io.File
8+
9+
@Component
10+
class AccessKeysReader(
11+
@Value("\${lapis.accessKeys.path:#{null}}") private val accessKeysFile: String?,
12+
private val yamlObjectMapper: YamlObjectMapper,
13+
) {
14+
fun read(): AccessKeys {
15+
if (accessKeysFile == null) {
16+
throw IllegalArgumentException("Cannot read LAPIS access keys, lapis.accessKeys.path was not set.")
17+
}
18+
19+
return yamlObjectMapper.objectMapper.readValue(File(accessKeysFile))
20+
}
21+
}
22+
23+
data class AccessKeys(val fullAccessKey: String, val aggregatedDataAccessKey: String)

lapis2/src/main/kotlin/org/genspectrum/lapis/config/DatabaseConfig.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ data class DatabaseSchema(
1010
val features: List<DatabaseFeature> = emptyList(),
1111
)
1212

13-
data class DatabaseMetadata(val name: String, val type: String)
13+
data class DatabaseMetadata(val name: String, val type: String, val unique: Boolean = false)
1414

1515
data class DatabaseFeature(val name: String)
1616

lapis2/src/main/kotlin/org/genspectrum/lapis/controller/LapisController.kt

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import io.swagger.v3.oas.annotations.media.ArraySchema
88
import io.swagger.v3.oas.annotations.media.Content
99
import io.swagger.v3.oas.annotations.media.Schema
1010
import io.swagger.v3.oas.annotations.responses.ApiResponse
11+
import org.genspectrum.lapis.auth.ACCESS_KEY_PROPERTY
1112
import org.genspectrum.lapis.logging.RequestContext
1213
import org.genspectrum.lapis.model.SiloQueryModel
1314
import org.genspectrum.lapis.response.AggregatedResponse
@@ -27,6 +28,9 @@ private const val DEFAULT_MIN_PROPORTION = 0.05
2728

2829
@RestController
2930
class LapisController(private val siloQueryModel: SiloQueryModel, private val requestContext: RequestContext) {
31+
companion object {
32+
private val nonSequenceFilterFields = listOf(MIN_PROPORTION_PROPERTY, ACCESS_KEY_PROPERTY)
33+
}
3034

3135
@GetMapping("/aggregated")
3236
@LapisAggregatedResponse
@@ -41,7 +45,7 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
4145
): AggregatedResponse {
4246
requestContext.filter = sequenceFilters
4347

44-
return siloQueryModel.aggregate(sequenceFilters)
48+
return siloQueryModel.aggregate(sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) })
4549
}
4650

4751
@PostMapping("/aggregated")
@@ -53,7 +57,7 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
5357
): AggregatedResponse {
5458
requestContext.filter = sequenceFilters
5559

56-
return siloQueryModel.aggregate(sequenceFilters)
60+
return siloQueryModel.aggregate(sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) })
5761
}
5862

5963
@GetMapping("/nucleotideMutations")
@@ -72,7 +76,7 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
7276

7377
return siloQueryModel.computeMutationProportions(
7478
minProportion,
75-
sequenceFilters.filterKeys { it != MIN_PROPORTION_PROPERTY },
79+
sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) },
7680
)
7781
}
7882

@@ -85,9 +89,11 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
8589
): List<MutationData> {
8690
requestContext.filter = requestBody
8791

88-
val (minProportions, sequenceFilters) = requestBody.entries.partition { it.key == MIN_PROPORTION_PROPERTY }
92+
val (nonSequenceFilters, sequenceFilters) = requestBody.entries.partition {
93+
nonSequenceFilterFields.contains(it.key)
94+
}
8995

90-
val maybeMinProportion = minProportions.getOrNull(0)?.value
96+
val maybeMinProportion = nonSequenceFilters.find { it.key == MIN_PROPORTION_PROPERTY }?.value
9197
val minProportion = try {
9298
maybeMinProportion?.toDouble() ?: DEFAULT_MIN_PROPORTION
9399
} catch (exception: IllegalArgumentException) {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package org.genspectrum.lapis.util
2+
3+
import jakarta.servlet.ReadListener
4+
import jakarta.servlet.ServletInputStream
5+
import jakarta.servlet.http.HttpServletRequest
6+
import jakarta.servlet.http.HttpServletRequestWrapper
7+
import java.io.ByteArrayInputStream
8+
import java.io.ByteArrayOutputStream
9+
import java.io.IOException
10+
import java.io.InputStream
11+
12+
class CachedBodyHttpServletRequest(request: HttpServletRequest) : HttpServletRequestWrapper(request) {
13+
private val cachedBody: ByteArray by lazy {
14+
val inputStream: InputStream = request.inputStream
15+
val byteArrayOutputStream = ByteArrayOutputStream()
16+
17+
inputStream.copyTo(byteArrayOutputStream)
18+
byteArrayOutputStream.toByteArray()
19+
}
20+
21+
@Throws(IOException::class)
22+
override fun getInputStream(): ServletInputStream {
23+
return CachedBodyServletInputStream(ByteArrayInputStream(cachedBody))
24+
}
25+
26+
private inner class CachedBodyServletInputStream(private val cachedInputStream: ByteArrayInputStream) :
27+
ServletInputStream() {
28+
29+
override fun isFinished(): Boolean {
30+
return cachedInputStream.available() == 0
31+
}
32+
33+
override fun isReady(): Boolean {
34+
return true
35+
}
36+
37+
override fun setReadListener(listener: ReadListener) {
38+
throw UnsupportedOperationException("setReadListener is not supported")
39+
}
40+
41+
@Throws(IOException::class)
42+
override fun read(): Int {
43+
return cachedInputStream.read()
44+
}
45+
}
46+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.genspectrum.lapis.util
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper
4+
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
5+
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
6+
import org.springframework.stereotype.Component
7+
8+
@Component
9+
object YamlObjectMapper {
10+
val objectMapper: ObjectMapper = ObjectMapper(YAMLFactory()).registerKotlinModule()
11+
}

0 commit comments

Comments
 (0)