upload scripts

This commit is contained in:
Quaternions 2024-12-18 16:09:29 -08:00
parent 79ba77e7d9
commit 6f3d3b170d

View File

@ -1,13 +1,12 @@
use std::{io::{Read, Seek}, path::PathBuf}; use std::{io::{Read, Seek}, path::PathBuf};
use clap::{Args, Parser, Subcommand}; use clap::{Args, Parser, Subcommand};
use anyhow::Result as AResult; use anyhow::Result as AResult;
use futures::{StreamExt,TryStreamExt};
#[derive(Parser)] #[derive(Parser)]
#[command(author,version,about,long_about=None)] #[command(author,version,about,long_about=None)]
#[command(propagate_version=true)] #[command(propagate_version=true)]
struct Cli{ struct Cli{
#[arg(long)]
path:Option<PathBuf>,
#[command(subcommand)] #[command(subcommand)]
command:Commands, command:Commands,
} }
@ -19,6 +18,7 @@ enum Commands {
Replace, Replace,
Scan, Scan,
Upload, Upload,
UploadScripts(UploadScriptsCommand)
} }
#[derive(Args)] #[derive(Args)]
@ -31,7 +31,14 @@ struct MapList {
maps: Vec<u64>, maps: Vec<u64>,
} }
fn main() -> AResult<()> { #[derive(Args)]
struct UploadScriptsCommand{
#[arg(long)]
session_id:PathBuf,
}
#[tokio::main]
async fn main() -> AResult<()> {
let cli = Cli::parse(); let cli = Cli::parse();
match cli.command { match cli.command {
Commands::ExtractScripts(pathlist)=>extract_scripts(pathlist.paths), Commands::ExtractScripts(pathlist)=>extract_scripts(pathlist.paths),
@ -39,6 +46,7 @@ fn main() -> AResult<()> {
Commands::Replace=>replace(), Commands::Replace=>replace(),
Commands::Scan=>scan(), Commands::Scan=>scan(),
Commands::Upload=>upload(), Commands::Upload=>upload(),
Commands::UploadScripts(command)=>upload_scripts(command.session_id).await,
} }
} }
@ -753,3 +761,144 @@ fn interactive() -> AResult<()>{
std::fs::write("id",id.to_string())?; std::fs::write("id",id.to_string())?;
Ok(()) Ok(())
} }
fn hash_source(source:&str)->u64{
let mut hasher=siphasher::sip::SipHasher::new();
std::hash::Hasher::write(&mut hasher,source.as_bytes());
std::hash::Hasher::finish(&hasher)
}
fn hash_format(hash:u64)->String{
format!("{:016x}",hash)
}
type GOCError=submissions_api::types::SingleItemError;
type GOCResult=Result<submissions_api::types::ScriptID,GOCError>;
async fn get_or_create_script(api:&submissions_api::external::Context,source:&str)->GOCResult{
let script_response=api.get_script_from_hash(submissions_api::types::HashRequest{
hash:hash_format(hash_source(source)).as_str(),
}).await?;
Ok(match script_response{
Some(script_response)=>script_response.ID,
None=>api.create_script(submissions_api::types::CreateScriptRequest{
Name:"Script",
Source:source,
SubmissionID:None,
}).await.map_err(GOCError::Other)?.ID
})
}
async fn check_or_create_script_poicy(
api:&submissions_api::external::Context,
hash:&str,
script_policy:submissions_api::types::CreateScriptPolicyRequest,
)->Result<(),GOCError>{
let script_policy_result=api.get_script_policy_from_hash(submissions_api::types::HashRequest{
hash,
}).await?;
match script_policy_result{
Some(script_policy_reponse)=>{
// check that everything matches the expectation
assert!(hash==script_policy_reponse.FromScriptHash);
assert!(script_policy.ToScriptID==script_policy_reponse.ToScriptID);
assert!(script_policy.Policy==script_policy_reponse.Policy);
},
None=>{
// create a new policy
api.create_script_policy(script_policy).await.map_err(GOCError::Other)?;
}
}
Ok(())
}
async fn do_policy(
api:&submissions_api::external::Context,
script_ids:&std::collections::HashMap<&str,submissions_api::types::ScriptID>,
source:&str,
to_script_id:submissions_api::types::ScriptID,
policy:submissions_api::types::Policy,
)->Result<(),GOCError>{
let hash=hash_format(hash_source(source));
check_or_create_script_poicy(api,hash.as_str(),submissions_api::types::CreateScriptPolicyRequest{
FromScriptID:script_ids[source],
ToScriptID:to_script_id,
Policy:policy,
}).await
}
async fn upload_scripts(session_id:PathBuf)->AResult<()>{
let cookie={
let mut cookie=String::new();
std::fs::File::open(session_id)?.read_to_string(&mut cookie)?;
submissions_api::Cookie::new(&cookie)?
};
let api=&submissions_api::external::Context::new("http://localhost:8082".to_owned(),cookie)?;
let allowed_set=get_allowed_set()?;
let allowed_map=get_allowed_map()?;
let replace_map=get_replace_map()?;
let blocked=get_blocked()?;
// create a unified deduplicated set of all scripts
let script_set:std::collections::HashSet<&str>=allowed_set.iter()
.map(|s|s.as_str())
.chain(
replace_map.keys().map(|s|s.as_str())
).chain(
blocked.iter().map(|s|s.as_str())
).collect();
// get or create every unique script
let script_ids:std::collections::HashMap<&str,submissions_api::types::ScriptID>=
futures::stream::iter(script_set)
.map(|source|async move{
let script_id=get_or_create_script(api,source).await?;
Ok::<_,GOCError>((source,script_id))
})
.buffer_unordered(16)
.try_collect().await?;
// get or create policy for each script in each category
//
// replace
futures::stream::iter(replace_map.iter().map(Ok))
.try_for_each_concurrent(Some(16),|(source,id)|async{
do_policy(
api,
&script_ids,
source,
script_ids[allowed_map[id].as_str()],
submissions_api::types::Policy::Replace
).await
}).await?;
// allowed
futures::stream::iter(allowed_set.iter().map(Ok))
.try_for_each_concurrent(Some(16),|source|async{
do_policy(
api,
&script_ids,
source,
script_ids[source.as_str()],
submissions_api::types::Policy::Allowed
).await
}).await?;
// blocked
futures::stream::iter(blocked.iter().map(Ok))
.try_for_each_concurrent(Some(16),|source|async{
do_policy(
api,
&script_ids,
source,
script_ids[source.as_str()],
submissions_api::types::Policy::Blocked
).await
}).await?;
Ok(())
}