1pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14 Runtime,
15 wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22 AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
23 EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
24};
25
26#[derive(Debug, Error)]
27pub enum LoadError {
28 #[error("failed to read module")]
29 Read(#[from] tokio::io::Error),
30
31 #[error("failed to create WASM engine")]
32 Engine(#[source] anyhow::Error),
33
34 #[error("module compilation task crashed")]
35 CompilationTask(#[from] tokio::task::JoinError),
36
37 #[error("failed to compile WASM module")]
38 Compilation(#[source] anyhow::Error),
39
40 #[error("invalid policy data")]
41 InvalidData(#[source] anyhow::Error),
42
43 #[error("failed to instantiate a test instance")]
44 Instantiate(#[source] InstantiateError),
45}
46
47impl LoadError {
48 #[doc(hidden)]
51 #[must_use]
52 pub fn invalid_data_example() -> Self {
53 Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
54 }
55}
56
57#[derive(Debug, Error)]
58pub enum InstantiateError {
59 #[error("failed to create WASM runtime")]
60 Runtime(#[source] anyhow::Error),
61
62 #[error("missing entrypoint {entrypoint}")]
63 MissingEntrypoint { entrypoint: String },
64
65 #[error("failed to load policy data")]
66 LoadData(#[source] anyhow::Error),
67}
68
/// Names of the policy entrypoints evaluated by each [`Policy`] method.
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint used by [`Policy::evaluate_register`].
    pub register: String,
    /// Entrypoint used by [`Policy::evaluate_client_registration`].
    pub client_registration: String,
    /// Entrypoint used by [`Policy::evaluate_authorization_grant`].
    pub authorization_grant: String,
    /// Entrypoint used by [`Policy::evaluate_email`].
    pub email: String,
}

impl Entrypoints {
    /// Borrows every entrypoint name, in declaration order, so they can all be
    /// checked against the entrypoints exported by the module.
    fn all(&self) -> [&str; 4] {
        let Self {
            register,
            client_registration,
            authorization_grant,
            email,
        } = self;

        [register, client_registration, authorization_grant, email].map(String::as_str)
    }
}
88
89#[derive(Debug)]
90pub struct Data {
91 base: BaseData,
92
93 rest: Option<serde_json::Value>,
95}
96
97#[derive(Serialize, Debug)]
98struct BaseData {
99 server_name: String,
100
101 session_limit: Option<SessionLimitConfig>,
103}
104
105impl Data {
106 #[must_use]
107 pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
108 Self {
109 base: BaseData {
110 server_name,
111 session_limit,
112 },
113
114 rest: None,
115 }
116 }
117
118 #[must_use]
119 pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
120 self.rest = Some(rest);
121 self
122 }
123
124 fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
125 let base = serde_json::to_value(&self.base)?;
126
127 if let Some(rest) = &self.rest {
128 merge_data(base, rest.clone())
129 } else {
130 Ok(base)
131 }
132 }
133}
134
135fn value_kind(value: &serde_json::Value) -> &'static str {
136 match value {
137 serde_json::Value::Object(_) => "object",
138 serde_json::Value::Array(_) => "array",
139 serde_json::Value::String(_) => "string",
140 serde_json::Value::Number(_) => "number",
141 serde_json::Value::Bool(_) => "boolean",
142 serde_json::Value::Null => "null",
143 }
144}
145
146fn merge_data(
147 mut left: serde_json::Value,
148 right: serde_json::Value,
149) -> Result<serde_json::Value, anyhow::Error> {
150 merge_data_rec(&mut left, right)?;
151 Ok(left)
152}
153
154fn merge_data_rec(
155 left: &mut serde_json::Value,
156 right: serde_json::Value,
157) -> Result<(), anyhow::Error> {
158 match (left, right) {
159 (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
160 for (key, value) in right {
161 if let Some(left_value) = left.get_mut(&key) {
162 merge_data_rec(left_value, value)?;
163 } else {
164 left.insert(key, value);
165 }
166 }
167 }
168 (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
169 left.extend(right);
170 }
171 (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
173 *left = right;
174 }
175 (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
176 *left = right;
177 }
178 (serde_json::Value::String(left), serde_json::Value::String(right)) => {
179 *left = right;
180 }
181
182 (left, right) if left.is_null() => *left = right,
184
185 (left, right) if right.is_null() => *left = right,
187
188 (left, right) => anyhow::bail!(
189 "Cannot merge a {} into a {}",
190 value_kind(&right),
191 value_kind(left),
192 ),
193 }
194
195 Ok(())
196}
197
198struct DynamicData {
199 version: Option<Ulid>,
200 merged: serde_json::Value,
201}
202
203pub struct PolicyFactory {
204 engine: Engine,
205 module: Module,
206 data: Data,
207 dynamic_data: ArcSwap<DynamicData>,
208 entrypoints: Entrypoints,
209}
210
211impl PolicyFactory {
212 #[tracing::instrument(name = "policy.load", skip(source))]
218 pub async fn load(
219 mut source: impl AsyncRead + std::marker::Unpin,
220 data: Data,
221 entrypoints: Entrypoints,
222 ) -> Result<Self, LoadError> {
223 let mut config = Config::default();
224 config.async_support(true);
225 config.cranelift_opt_level(OptLevel::SpeedAndSize);
226
227 let engine = Engine::new(&config).map_err(LoadError::Engine)?;
228
229 let mut buf = Vec::new();
231 source.read_to_end(&mut buf).await?;
232 let (engine, module) = tokio::task::spawn_blocking(move || {
234 let module = Module::new(&engine, buf)?;
235 anyhow::Ok((engine, module))
236 })
237 .await?
238 .map_err(LoadError::Compilation)?;
239
240 let merged = data.to_value().map_err(LoadError::InvalidData)?;
241 let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
242 version: None,
243 merged,
244 }));
245
246 let factory = Self {
247 engine,
248 module,
249 data,
250 dynamic_data,
251 entrypoints,
252 };
253
254 factory
256 .instantiate()
257 .await
258 .map_err(LoadError::Instantiate)?;
259
260 Ok(factory)
261 }
262
263 pub async fn set_dynamic_data(
276 &self,
277 dynamic_data: mas_data_model::PolicyData,
278 ) -> Result<bool, LoadError> {
279 if self.dynamic_data.load().version == Some(dynamic_data.id) {
282 return Ok(false);
284 }
285
286 let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
287 let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
288
289 self.instantiate_with_data(&merged)
291 .await
292 .map_err(LoadError::Instantiate)?;
293
294 self.dynamic_data.store(Arc::new(DynamicData {
296 version: Some(dynamic_data.id),
297 merged,
298 }));
299
300 Ok(true)
301 }
302
303 #[tracing::instrument(name = "policy.instantiate", skip_all)]
310 pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
311 let data = self.dynamic_data.load();
312 self.instantiate_with_data(&data.merged).await
313 }
314
315 async fn instantiate_with_data(
316 &self,
317 data: &serde_json::Value,
318 ) -> Result<Policy, InstantiateError> {
319 let mut store = Store::new(&self.engine, ());
320 let runtime = Runtime::new(&mut store, &self.module)
321 .await
322 .map_err(InstantiateError::Runtime)?;
323
324 let policy_entrypoints = runtime.entrypoints();
326
327 for e in self.entrypoints.all() {
328 if !policy_entrypoints.contains(e) {
329 return Err(InstantiateError::MissingEntrypoint {
330 entrypoint: e.to_owned(),
331 });
332 }
333 }
334
335 let instance = runtime
336 .with_data(&mut store, data)
337 .await
338 .map_err(InstantiateError::LoadData)?;
339
340 Ok(Policy {
341 store,
342 instance,
343 entrypoints: self.entrypoints.clone(),
344 })
345 }
346}
347
348pub struct Policy {
349 store: Store<()>,
350 instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
351 entrypoints: Entrypoints,
352}
353
354#[derive(Debug, Error)]
355#[error("failed to evaluate policy")]
356pub enum EvaluationError {
357 Serialization(#[from] serde_json::Error),
358 Evaluation(#[from] anyhow::Error),
359}
360
361impl Policy {
362 #[tracing::instrument(
368 name = "policy.evaluate_email",
369 skip_all,
370 fields(
371 %input.email,
372 ),
373 )]
374 pub async fn evaluate_email(
375 &mut self,
376 input: EmailInput<'_>,
377 ) -> Result<EvaluationResult, EvaluationError> {
378 let [res]: [EvaluationResult; 1] = self
379 .instance
380 .evaluate(&mut self.store, &self.entrypoints.email, &input)
381 .await?;
382
383 Ok(res)
384 }
385
386 #[tracing::instrument(
392 name = "policy.evaluate.register",
393 skip_all,
394 fields(
395 ?input.registration_method,
396 input.username = input.username,
397 input.email = input.email,
398 ),
399 )]
400 pub async fn evaluate_register(
401 &mut self,
402 input: RegisterInput<'_>,
403 ) -> Result<EvaluationResult, EvaluationError> {
404 let [res]: [EvaluationResult; 1] = self
405 .instance
406 .evaluate(&mut self.store, &self.entrypoints.register, &input)
407 .await?;
408
409 Ok(res)
410 }
411
412 #[tracing::instrument(skip(self))]
418 pub async fn evaluate_client_registration(
419 &mut self,
420 input: ClientRegistrationInput<'_>,
421 ) -> Result<EvaluationResult, EvaluationError> {
422 let [res]: [EvaluationResult; 1] = self
423 .instance
424 .evaluate(
425 &mut self.store,
426 &self.entrypoints.client_registration,
427 &input,
428 )
429 .await?;
430
431 Ok(res)
432 }
433
434 #[tracing::instrument(
440 name = "policy.evaluate.authorization_grant",
441 skip_all,
442 fields(
443 %input.scope,
444 %input.client.id,
445 ),
446 )]
447 pub async fn evaluate_authorization_grant(
448 &mut self,
449 input: AuthorizationGrantInput<'_>,
450 ) -> Result<EvaluationResult, EvaluationError> {
451 let [res]: [EvaluationResult; 1] = self
452 .instance
453 .evaluate(
454 &mut self.store,
455 &self.entrypoints.authorization_grant,
456 &input,
457 )
458 .await?;
459
460 Ok(res)
461 }
462}
463
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use super::*;

    /// Entrypoint names exported by the compiled policy under `policies/`.
    fn entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Loads `policies/policy.wasm` (two directories above this crate) with `data`.
    async fn load_factory(data: Data) -> PolicyFactory {
        // NOTE(review): `std::path` types are presumably disallowed by the
        // workspace clippy config; building the fixture path with them is fine
        // in tests.
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, entrypoints()).await.unwrap()
    }

    /// A password registration for the user `hello` using `email`, with no
    /// requester details.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Dynamic policy data with a nil id and `data` as its content.
    fn policy_data(data: serde_json::Value) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data,
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // Not in `allowed_domains`.
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());

        // Matches the `*.element.io` pattern.
        let res = policy
            .evaluate_register(password_registration("hello@foo.element.io"))
            .await
            .unwrap();
        assert!(res.valid());

        // Matches `*.element.io`, but is listed in `banned_domains`.
        let res = policy
            .evaluate_register(password_registration("hello@staging.element.io"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        let banned = serde_json::json!({
            "emails": {
                "banned_addresses": {
                    "substrings": ["hello"]
                }
            }
        });

        // New data is applied the first time it is submitted...
        assert!(
            factory
                .set_dynamic_data(policy_data(banned.clone()))
                .await
                .unwrap()
        );
        // ...and skipped when the same id is submitted again.
        assert!(
            !factory
                .set_dynamic_data(policy_data(banned))
                .await
                .unwrap()
        );

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // 131072 five-digit substrings, to exercise loading a large document.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let json = serde_json::json!({
            "emails": { "banned_addresses": { "substrings": substrings } }
        });
        factory
            .set_dynamic_data(policy_data(json))
            .await
            .unwrap();

        let mut policy = factory.instantiate().await.unwrap();
        // "12345" is one of the banned substrings.
        let res = policy
            .evaluate_register(password_registration("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
        assert_eq!(res, j!({"hello": "world", "foo": "bar"}));

        let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
        assert_eq!(res, j!({"hello": "john"}));

        let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
        assert_eq!(res, j!({"hello": false}));

        let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
        assert_eq!(res, j!({"hello": 42}));

        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");

        // The error names the incoming kind first, then the existing one.
        let err = merge_data(j!({"hello": ["world"]}), j!({"hello": {"foo": "bar"}}))
            .expect_err("Can't merge an object into an array");
        assert_eq!(err.to_string(), "Cannot merge a object into a array");

        let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
        assert_eq!(res, j!({"hello": ["world", "john"]}));

        let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
        assert_eq!(res, j!({"hello": null}));

        let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
        assert_eq!(res, j!({"hello": "world"}));

        let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
        assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));
    }
}