2024-12-26 19:44:44 -08:00

420 lines
12 KiB
Rust

use clap::{Args,Parser,Subcommand};
use futures::{StreamExt,TryStreamExt};
/// Upper bound on local file reads kept in flight at once (used with `buffer_unordered`).
const READ_CONCURRENCY: usize = 16;
/// Upper bound on concurrent requests sent to the submissions API.
const REMOTE_CONCURRENCY: usize = 16;
// Top-level command-line interface.
// Only plain `//` comments here: clap would turn `///` doc comments into help text.
#[derive(Parser)]
#[command(author, version, about, long_about = None, propagate_version = true)]
struct Cli {
	// the subcommand selected on the command line
	#[command(subcommand)]
	command: Commands,
}
// Available subcommands (plain `//` comments so clap help text is unchanged).
#[derive(Subcommand)]
enum Commands {
	// interactively review pending script policies
	Review(ReviewCommand),
	// push the local scripts/ tree and its policies to the API
	UploadScripts(UploadScriptsCommand),
}
// Arguments of the `review` subcommand, both given as long flags.
#[derive(Args)]
struct ReviewCommand {
	// session cookie value used to authenticate against the API
	#[arg(long)]
	session_id: String,
	// base URL of the submissions API
	#[arg(long)]
	api_url: String,
}
// Arguments of the `upload-scripts` subcommand, both given as long flags.
#[derive(Args)]
struct UploadScriptsCommand {
	// session cookie value used to authenticate against the API
	#[arg(long)]
	session_id: String,
	// base URL of the submissions API
	#[arg(long)]
	api_url: String,
}
#[tokio::main]
async fn main(){
let cli=Cli::parse();
match cli.command{
Commands::Review(command)=>review(ReviewConfig{
session_id:command.session_id,
api_url:command.api_url,
}).await.unwrap(),
Commands::UploadScripts(command)=>upload_scripts(UploadConfig{
session_id:command.session_id,
api_url:command.api_url,
}).await.unwrap(),
}
}
/// An action chosen by the reviewer at the `action: ` prompt.
#[derive(Debug,Clone,Copy,PartialEq,Eq)]
enum ScriptActionParseResult{
	/// Accept the script (or replace it with the edited `current.lua`).
	Pass,
	/// Set the policy to `Blocked`.
	Block,
	/// Stop reviewing.
	Exit,
	/// Set the policy to `Delete`.
	Delete,
}
/// Error returned when the input is not one of the recognised action words.
#[derive(Debug)]
struct ParseScriptActionErr;
impl std::str::FromStr for ScriptActionParseResult{
	type Err=ParseScriptActionErr;
	/// Parse one of the action words `pass`, `block`, `exit`, `delete`.
	///
	/// Leading/trailing whitespace is ignored, so a raw line from `read_line`
	/// parses whether it ends in `\n` or `\r\n`. Matching is case-sensitive.
	///
	/// # Errors
	/// [`ParseScriptActionErr`] for any other input, including an empty string.
	fn from_str(s:&str)->Result<Self,Self::Err>{
		// Matching "pass\n" literally rejected Windows "\r\n" line endings and
		// stray spaces, making the prompt loop forever on valid input.
		match s.trim(){
			"pass"=>Ok(Self::Pass),
			"block"=>Ok(Self::Block),
			"exit"=>Ok(Self::Exit),
			"delete"=>Ok(Self::Delete),
			_=>Err(ParseScriptActionErr),
		}
	}
}
/// Everything that can go wrong while running the `review` subcommand.
// The payloads are only observed through the derived `Debug` output (main unwraps
// the result), and the dead_code lint ignores derived `Debug`, hence the allow.
#[allow(dead_code)]
#[derive(Debug)]
enum ReviewError {
	/// Building the session cookie failed.
	Cookie(submissions_api::CookieError),
	/// Constructing the API client context failed.
	Reqwest(submissions_api::ReqwestError),
	/// Fetching the list of unreviewed policies failed.
	GetPolicies(submissions_api::Error),
	/// Looking up a policy's source script by hash failed.
	GetScriptFromHash(submissions_api::types::SingleItemError),
	/// A policy's source script hash matched no script.
	NoScript,
	/// Writing `current.lua` failed.
	WriteCurrent(std::io::Error),
	/// Prompting for or reading the action from the terminal failed.
	ActionIO(std::io::Error),
	/// Reading back `current.lua` failed.
	ReadCurrent(std::io::Error),
	/// Looking up an edited script by hash failed.
	DeduplicateModified(submissions_api::types::SingleItemError),
	/// Uploading an edited script failed.
	UploadModified(submissions_api::Error),
	/// Saving the reviewed policy failed.
	UpdateScriptPolicy(submissions_api::Error),
}
/// Settings needed by [`review`].
struct ReviewConfig {
	/// Session cookie value used to authenticate.
	session_id: String,
	/// Base URL of the submissions API.
	api_url: String,
}
/// Interactively review every script policy that has no decision yet.
///
/// For each unreviewed policy (first page only, at most 100), the script's
/// source is written to `current.lua` and the user is prompted for an action:
/// - `pass`: if `current.lua` is unchanged the policy becomes `Allowed`;
///   otherwise the edited source is looked up by hash (uploaded if absent) and
///   the policy becomes `Replace`, pointing at the edited script.
/// - `block` / `delete`: the policy becomes `Blocked` / `Delete`.
/// - `exit`, or end of input on stdin: stop reviewing and return `Ok(())`.
///
/// # Errors
/// Returns a [`ReviewError`] for any API or file/terminal I/O failure.
async fn review(config:ReviewConfig)->Result<(),ReviewError>{
	let cookie=submissions_api::Cookie::new(&config.session_id).map_err(ReviewError::Cookie)?;
	let api=submissions_api::external::Context::new(config.api_url,cookie).map_err(ReviewError::Reqwest)?;
	// download unreviewed policies
	let unreviewed_policies=api.get_script_policies(submissions_api::types::GetScriptPoliciesRequest{
		Page:1,
		Limit:100,
		FromScriptHash:None,
		ToScriptID:None,
		Policy:Some(submissions_api::types::Policy::None),
	}).await.map_err(ReviewError::GetPolicies)?;
	for unreviewed_policy in unreviewed_policies{
		// download source code
		let script_response=api.get_script_from_hash(submissions_api::types::HashRequest{
			hash:unreviewed_policy.FromScriptHash.as_str(),
		}).await
		.map_err(ReviewError::GetScriptFromHash)?
		.ok_or(ReviewError::NoScript)?;
		let source=script_response.Source;
		// load source into current.lua so the reviewer can inspect/edit it
		tokio::fs::write("current.lua",source.as_str()).await.map_err(ReviewError::WriteCurrent)?;
		// prompt until a valid action is entered
		let script_action=loop{
			print!("action: ");
			std::io::Write::flush(&mut std::io::stdout()).map_err(ReviewError::ActionIO)?;
			let mut action_string=String::new();
			let bytes_read=std::io::stdin().read_line(&mut action_string).map_err(ReviewError::ActionIO)?;
			// read_line yields 0 bytes only at end of input; without this the
			// loop would spin forever re-printing the prompt
			if bytes_read==0{
				break ScriptActionParseResult::Exit;
			}
			if let Ok(parsed_script_action)=action_string.parse::<ScriptActionParseResult>(){
				break parsed_script_action;
			}
		};
		// only a Replace policy targets another script; otherwise ToScriptID stays unset
		let mut to_script_id=None;
		// interpret action
		let reviewed_policy=match script_action{
			ScriptActionParseResult::Pass=>{
				// if current.lua was edited, the edited script replaces the original
				let modified_source=tokio::fs::read_to_string("current.lua").await.map_err(ReviewError::ReadCurrent)?;
				if modified_source==source{
					submissions_api::types::Policy::Allowed
				}else{
					// Hash the EDITED source: hashing the original here would always find
					// the original script, so the edit would never be uploaded and the
					// Replace policy would point the script at itself.
					let maybe_script_response=api.get_script_from_hash(submissions_api::types::HashRequest{
						hash:hash_format(hash_source(modified_source.as_str())).as_str(),
					}).await.map_err(ReviewError::DeduplicateModified)?;
					// reuse an identical existing script, otherwise upload the edit
					to_script_id=Some(match maybe_script_response{
						Some(existing_script)=>existing_script.ID,
						None=>api.create_script(submissions_api::types::CreateScriptRequest{
							Name:script_response.Name.as_str(),
							Source:modified_source.as_str(),
							SubmissionID:Some(script_response.SubmissionID),
						}).await.map_err(ReviewError::UploadModified)?.ID
					});
					submissions_api::types::Policy::Replace
				}
			},
			ScriptActionParseResult::Block=>submissions_api::types::Policy::Blocked,
			ScriptActionParseResult::Exit=>break,
			ScriptActionParseResult::Delete=>submissions_api::types::Policy::Delete,
		};
		// update policy
		api.update_script_policy(submissions_api::types::UpdateScriptPolicyRequest{
			ScriptPolicyID:unreviewed_policy.ID,
			FromScriptID:None,
			ToScriptID:to_script_id,
			Policy:Some(reviewed_policy),
		}).await.map_err(ReviewError::UpdateScriptPolicy)?;
	}
	Ok(())
}
/// Everything that can go wrong while running the `upload-scripts` subcommand.
// The payloads are only observed through the derived `Debug` output (main unwraps
// the result), and the dead_code lint ignores derived `Debug`, hence the allow.
#[allow(dead_code)]
#[derive(Debug)]
enum ScriptUploadError {
	/// Building the session cookie failed.
	Cookie(submissions_api::CookieError),
	/// Constructing the API client context failed.
	Reqwest(submissions_api::ReqwestError),
	/// Reading `scripts/allowed` as a set of sources failed.
	AllowedSet(std::io::Error),
	/// Reading `scripts/allowed` as an id -> source map failed.
	AllowedMap(GetMapError),
	/// Problem with `scripts/replace` (see the wrapped `GetMapError`).
	ReplaceMap(GetMapError),
	/// Reading `scripts/blocked` as a set of sources failed.
	BlockedSet(std::io::Error),
	/// Getting or creating a remote script failed.
	GOC(GOCError),
	/// Checking or creating a `Replace` policy failed.
	GOCPolicyReplace(GOCError),
	/// Checking or creating an `Allowed` policy failed.
	GOCPolicyAllowed(GOCError),
	/// Checking or creating a `Blocked` policy failed.
	GOCPolicyBlocked(GOCError),
}
/// Adapt a `tokio::fs::ReadDir` into a `Stream` of directory entries.
///
/// The stream ends once `next_entry` reports no more entries. An I/O error is
/// yielded as an `Err` item, and the following poll calls `next_entry` again.
fn read_dir_stream(dir:tokio::fs::ReadDir)->impl futures::stream::Stream<Item=std::io::Result<tokio::fs::DirEntry>>{
	futures::stream::unfold(dir,|mut entries|async move{
		// Ok(None) -> None ends the stream; Ok(Some(e)) / Err(e) become one item each
		entries.next_entry().await
			.transpose()
			.map(|item|(item,entries))
	})
}
/// Read every entry of the directory `path` and collect the file contents into a set.
///
/// Despite the name, `path` is a directory. Each entry is read with
/// `tokio::fs::read_to_string`, with up to `READ_CONCURRENCY` reads in flight;
/// identical contents collapse into a single set element.
///
/// # Errors
/// An I/O error from listing the directory or reading any entry.
async fn get_set_from_file(path:impl AsRef<std::path::Path>)->std::io::Result<std::collections::HashSet<String>>{
	let entries=read_dir_stream(tokio::fs::read_dir(path).await?);
	let reads=entries.map(|entry|async{
		let entry=entry?;
		tokio::fs::read_to_string(entry.path()).await
	});
	reads.buffer_unordered(READ_CONCURRENCY).try_collect().await
}
/// Sources of every file in `scripts/allowed`.
async fn get_allowed_set()->std::io::Result<std::collections::HashSet<String>>{
	const ALLOWED_DIR:&str="scripts/allowed";
	get_set_from_file(ALLOWED_DIR).await
}
/// Sources of every file in `scripts/blocked`.
async fn get_blocked_set()->std::io::Result<std::collections::HashSet<String>>{
	const BLOCKED_DIR:&str="scripts/blocked";
	get_set_from_file(BLOCKED_DIR).await
}
/// Failure while reading a directory of `<id>.<ext>` script files into a map.
// The payloads are only observed through the derived `Debug` output, which the
// dead_code lint deliberately ignores, hence the allow.
#[allow(dead_code)]
#[derive(Debug)]
enum GetMapError {
	/// Listing the directory, reading an entry, or reading a file failed.
	IO(std::io::Error),
	/// The file path has no file stem to take the script id from.
	FileStem,
	/// The file stem is not valid UTF-8.
	ToStr,
	/// The file stem is not a valid `u32` script id.
	ParseInt(std::num::ParseIntError),
}
/// Parse the script id from a file path's stem (e.g. `scripts/allowed/123.lua` -> `123`).
fn script_id_from_path(path:&std::path::Path)->Result<u32,GetMapError>{
	path.file_stem()
		.ok_or(GetMapError::FileStem)?
		.to_str()
		.ok_or(GetMapError::ToStr)?
		.parse()
		.map_err(GetMapError::ParseInt)
}
/// Read every file in `dir` as an `(id, source)` pair, taking the id from the file stem.
///
/// Up to `READ_CONCURRENCY` files are read at once; pairs come back in completion order.
async fn read_id_source_pairs(dir:impl AsRef<std::path::Path>)->Result<Vec<(u32,String)>,GetMapError>{
	read_dir_stream(tokio::fs::read_dir(dir).await.map_err(GetMapError::IO)?)
		.map(|dir_entry|async{
			let path=dir_entry.map_err(GetMapError::IO)?.path();
			let id=script_id_from_path(&path)?;
			let source=tokio::fs::read_to_string(path).await.map_err(GetMapError::IO)?;
			Ok((id,source))
		})
		.buffer_unordered(READ_CONCURRENCY)
		.try_collect().await
}
/// Map of script id -> source for every file in `scripts/allowed`.
///
/// # Errors
/// [`GetMapError`] if the directory cannot be read or a file name is not `<u32>.<ext>`.
async fn get_allowed_map()->Result<std::collections::HashMap<u32,String>,GetMapError>{
	Ok(read_id_source_pairs("scripts/allowed").await?.into_iter().collect())
}
/// Map of source -> allowed script id for every file in `scripts/replace`.
///
/// Each file's contents are the source to be replaced; its stem names the id of
/// the allowed script that replaces it. If two files share identical contents,
/// only one of them is kept.
///
/// # Errors
/// [`GetMapError`] if the directory cannot be read or a file name is not `<u32>.<ext>`.
async fn get_replace_map()->Result<std::collections::HashMap<String,u32>,GetMapError>{
	Ok(read_id_source_pairs("scripts/replace").await?
		.into_iter()
		.map(|(id,source)|(source,id))
		.collect())
}
/// Hash a script's source bytes with `siphasher`'s `SipHasher::new()`.
///
/// Rendered through `hash_format`, this value is the lookup key used with the API's
/// hash-based endpoints.
fn hash_source(source:&str)->u64{
	use std::hash::Hasher;
	let mut sip=siphasher::sip::SipHasher::new();
	sip.write(source.as_bytes());
	sip.finish()
}
/// Render a hash as exactly 16 lowercase, zero-padded hexadecimal digits.
fn hash_format(hash:u64)->String{
	format!("{hash:016x}")
}
/// Error type shared by the get-or-create helpers (the API's single-item lookup error).
type GOCError = submissions_api::types::SingleItemError;
/// Outcome of [`get_or_create_script`]: the id of the found or newly created script.
type GOCResult = Result<submissions_api::types::ScriptID, GOCError>;
/// Return the id of the remote script whose source is `source`, creating it if absent.
///
/// The lookup key is `hash_format(hash_source(source))`. A newly created script
/// is named `"Script"` and has no submission id.
///
/// # Errors
/// The lookup's error directly; a creation failure wrapped in `GOCError::Other`.
async fn get_or_create_script(api:&submissions_api::external::Context,source:&str)->GOCResult{
	let hash=hash_format(hash_source(source));
	let existing=api.get_script_from_hash(submissions_api::types::HashRequest{
		hash:hash.as_str(),
	}).await?;
	if let Some(script)=existing{
		return Ok(script.ID);
	}
	let created=api.create_script(submissions_api::types::CreateScriptRequest{
		Name:"Script",
		Source:source,
		SubmissionID:None,
	}).await.map_err(GOCError::Other)?;
	Ok(created.ID)
}
/// Ensure a script policy exists for the script whose hash is `hash`.
///
/// If the API already has a policy for `hash`, it is compared against
/// `script_policy`. Otherwise `script_policy` is created.
/// (The "poicy" spelling is kept because callers use this name.)
///
/// # Errors
/// The lookup's error directly; a creation failure wrapped in `GOCError::Other`.
///
/// # Panics
/// Panics if an existing policy's hash, `ToScriptID` or `Policy` differs from
/// the expectation. The local files and remote state disagree, and the message
/// names the offending hash so it can be fixed by hand.
async fn check_or_create_script_poicy(
	api:&submissions_api::external::Context,
	hash:&str,
	script_policy:submissions_api::types::CreateScriptPolicyRequest,
)->Result<(),GOCError>{
	let existing_policy=api.get_script_policy_from_hash(submissions_api::types::HashRequest{
		hash,
	}).await?;
	match existing_policy{
		Some(existing_policy)=>{
			// Many of these checks run concurrently, so each failure must say which
			// script and which field diverged.
			assert!(hash==existing_policy.FromScriptHash,"script policy for hash {} reports a different FromScriptHash",hash);
			assert!(script_policy.ToScriptID==existing_policy.ToScriptID,"script policy for hash {} has an unexpected ToScriptID",hash);
			assert!(script_policy.Policy==existing_policy.Policy,"script policy for hash {} has an unexpected Policy",hash);
		},
		None=>{
			// create a new policy
			api.create_script_policy(script_policy).await.map_err(GOCError::Other)?;
		},
	}
	Ok(())
}
/// Ensure the policy `policy` targeting `to_script_id` exists for `source`.
///
/// `source` must be a key of `script_ids` (indexing panics otherwise); its id
/// becomes the policy's `FromScriptID`. The policy is looked up by the source's hash.
async fn do_policy(
	api:&submissions_api::external::Context,
	script_ids:&std::collections::HashMap<&str,submissions_api::types::ScriptID>,
	source:&str,
	to_script_id:submissions_api::types::ScriptID,
	policy:submissions_api::types::Policy,
)->Result<(),GOCError>{
	let source_hash=hash_format(hash_source(source));
	let request=submissions_api::types::CreateScriptPolicyRequest{
		FromScriptID:script_ids[source],
		ToScriptID:to_script_id,
		Policy:policy,
	};
	check_or_create_script_poicy(api,&source_hash,request).await
}
/// Settings needed by [`upload_scripts`].
struct UploadConfig {
	/// Session cookie value used to authenticate.
	session_id: String,
	/// Base URL of the submissions API.
	api_url: String,
}
/// Push the local `scripts/` tree (scripts and their policies) to the API.
///
/// Reads `scripts/allowed`, `scripts/replace` and `scripts/blocked`, ensures every
/// distinct source exists remotely (creating missing ones), then ensures a policy
/// exists for each file:
/// - `replace/<id>.*`: `Replace`, pointing at the allowed script `allowed/<id>.*`
/// - `allowed/*`: `Allowed`, pointing at itself
/// - `blocked/*`: `Blocked`, pointing at itself
///
/// # Errors
/// Returns a [`ScriptUploadError`] for filesystem or API failures. A `ReplaceMap`
/// error (`NotFound`) is returned before any remote change if a replace file names
/// an id that has no file in `scripts/allowed`.
///
/// # Panics
/// Panics if an existing remote policy disagrees with the local files
/// (see `check_or_create_script_poicy`).
async fn upload_scripts(config:UploadConfig)->Result<(),ScriptUploadError>{
	let cookie=submissions_api::Cookie::new(&config.session_id).map_err(ScriptUploadError::Cookie)?;
	let api=&submissions_api::external::Context::new(config.api_url,cookie).map_err(ScriptUploadError::Reqwest)?;
	// read all local script directories concurrently
	let (
		allowed_set_result,
		allowed_map_result,
		replace_map_result,
		blocked_set_result,
	)=tokio::join!(
		get_allowed_set(),
		get_allowed_map(),
		get_replace_map(),
		get_blocked_set(),
	);
	let allowed_set=allowed_set_result.map_err(ScriptUploadError::AllowedSet)?;
	let allowed_map=allowed_map_result.map_err(ScriptUploadError::AllowedMap)?;
	let replace_map=replace_map_result.map_err(ScriptUploadError::ReplaceMap)?;
	let blocked_set=blocked_set_result.map_err(ScriptUploadError::BlockedSet)?;
	// Every replace target must name an allowed script. Check before touching the
	// remote, so a bad id is reported up front instead of surfacing as an anonymous
	// HashMap index panic after scripts were already created.
	if let Some(missing_id)=replace_map.values().find(|id|!allowed_map.contains_key(*id)){
		return Err(ScriptUploadError::ReplaceMap(GetMapError::IO(std::io::Error::new(
			std::io::ErrorKind::NotFound,
			format!("replace target id {} has no matching file in scripts/allowed",missing_id),
		))));
	}
	// create a unified deduplicated set of all scripts
	let script_set:std::collections::HashSet<&str>=allowed_set.iter()
		.map(String::as_str)
		.chain(
			replace_map.keys().map(String::as_str)
		).chain(
			blocked_set.iter().map(String::as_str)
		).collect();
	// get or create every unique script
	let script_ids:std::collections::HashMap<&str,submissions_api::types::ScriptID>=
		futures::stream::iter(script_set)
		.map(|source|async move{
			let script_id=get_or_create_script(api,source).await?;
			Ok((source,script_id))
		})
		.buffer_unordered(REMOTE_CONCURRENCY)
		.try_collect().await.map_err(ScriptUploadError::GOC)?;
	// get or create policy for each script in each category
	//
	// replace (allowed_map[id] cannot panic: validated above)
	let replace_fut=futures::stream::iter(replace_map.iter().map(Ok))
		.try_for_each_concurrent(Some(REMOTE_CONCURRENCY),|(source,id)|async{
			do_policy(
				api,
				&script_ids,
				source,
				script_ids[allowed_map[id].as_str()],
				submissions_api::types::Policy::Replace
			).await.map_err(ScriptUploadError::GOCPolicyReplace)
		});
	// allowed
	let allowed_fut=futures::stream::iter(allowed_set.iter().map(Ok))
		.try_for_each_concurrent(Some(REMOTE_CONCURRENCY),|source|async{
			do_policy(
				api,
				&script_ids,
				source,
				script_ids[source.as_str()],
				submissions_api::types::Policy::Allowed
			).await.map_err(ScriptUploadError::GOCPolicyAllowed)
		});
	// blocked
	let blocked_fut=futures::stream::iter(blocked_set.iter().map(Ok))
		.try_for_each_concurrent(Some(REMOTE_CONCURRENCY),|source|async{
			do_policy(
				api,
				&script_ids,
				source,
				script_ids[source.as_str()],
				submissions_api::types::Policy::Blocked
			).await.map_err(ScriptUploadError::GOCPolicyBlocked)
		});
	tokio::try_join!(replace_fut,allowed_fut,blocked_fut)?;
	Ok(())
}