962 lines
33 KiB
Rust
962 lines
33 KiB
Rust
//TODO include [Y/n] confirmation for API-calling commands and -y flag
|
|
//TODO Use tracing crate
|
|
//TODO Implement structured output for Gemini and Perplexity
|
|
//TODO If parsing the response fails in prettify(), try to parse as {"error":"msg"}
|
|
//TODO Print token usage in dbg print or something
|
|
//TODO Compile .typ from code
|
|
//TODO Retry after failing
|
|
use clap::{Parser, Subcommand, ValueEnum};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::{
|
|
collections::HashMap,
|
|
env,
|
|
fmt::{Debug, Display, Write as _},
|
|
fs,
|
|
path::{Path, PathBuf},
|
|
process::ExitCode,
|
|
str::FromStr,
|
|
};
|
|
use tokio::time;
|
|
use inline_colorization::*;
|
|
|
|
#[derive(Subcommand)]
|
|
enum Task {
|
|
/// Parses a given input file (.csv) and lists the contained companies. Panics if a `name` column does not exist.
|
|
List {
|
|
#[arg(short, long)]
|
|
csv_file: String,
|
|
},
|
|
/// Parses a given input file (.json) and counts the number of occurences per company as a reference client of other companies.
|
|
/// Expects the json file structure given by the LLM prompts
|
|
/// (TODO: Not quite. If there are multiple prompt results in a single file, concatination of multiple results has to be done manually at the moment. Thats why we like helix tho.)
|
|
CountPartners {
|
|
/// The json file, containing the required format.
|
|
#[arg(short, long)]
|
|
json_file: String,
|
|
},
|
|
/// WARNING: This might cost money to run!
|
|
/// Tests a given system prompt on a given model with the context prompt ""Research Target Company: Pandaloop\n Website: https://pandaloop.de"
|
|
TestWithPandaloop {
|
|
#[arg(short, long)]
|
|
model: Model,
|
|
#[arg(short, long)]
|
|
system_prompt_file: String,
|
|
#[arg(short, long)]
|
|
json_schema_file: Option<String>,
|
|
#[arg(short, long)]
|
|
pretty: bool,
|
|
},
|
|
/// WARNING: This might cost money to run!
|
|
SinglePrompt {
|
|
idx: usize,
|
|
#[arg(short, long)]
|
|
csv_file: String,
|
|
#[arg(short, long)]
|
|
model: Model,
|
|
#[arg(short, long)]
|
|
system_prompt_file: String,
|
|
#[arg(short, long)]
|
|
json_schema_file: Option<String>,
|
|
#[arg(short, long)]
|
|
pretty: bool,
|
|
},
|
|
/// WARNING: This might cost money to run!
|
|
/// Range boundaries are inclusive
|
|
RangePrompt {
|
|
range_start: usize,
|
|
range_end: usize,
|
|
#[arg(short, long)]
|
|
csv_file: String,
|
|
#[arg(short, long)]
|
|
model: Model,
|
|
#[arg(short, long)]
|
|
system_prompt_file: String,
|
|
#[arg(short, long)]
|
|
json_schema_file: Option<String>,
|
|
#[arg(short, long)]
|
|
pretty: bool,
|
|
},
|
|
/// WARNING: This might cost A LOT money to run!
|
|
/// Range boundaries are inclusive
|
|
MultiPrompt {
|
|
range_start: usize,
|
|
range_end: usize,
|
|
range_step: usize,
|
|
#[arg(short, long)]
|
|
csv_file: String,
|
|
#[arg(short, long)]
|
|
model: Model,
|
|
#[arg(short, long)]
|
|
system_prompt_file: String,
|
|
#[arg(short, long)]
|
|
json_schema_file: Option<String>,
|
|
#[arg(short, long)]
|
|
pretty: bool,
|
|
},
|
|
/// Expects JSON format from multiprompt with Lead evaluation system prompt
|
|
/// Formats the recieved almost-JSON into valid JSON and verifies
|
|
/// output_format = Typst: Formats leads into human-readable .typ file
|
|
ProcessLeadEvalJson {
|
|
#[arg(short, long)]
|
|
json_file: String,
|
|
#[arg(short, long)]
|
|
min_attractiveness: usize,
|
|
#[arg(short, long)]
|
|
output_format: OutputFormat,
|
|
},
|
|
}
|
|
|
|
#[derive(thiserror::Error, Debug)]
|
|
enum Error {
|
|
#[error("{style_bold}{color_red}API Response Error: {style_reset}{color_reset}\n{0}")]
|
|
APIResponseError(String),
|
|
#[error("{style_bold}{color_red}Unexpected response structure Error: {style_reset}{color_reset}\n{0}")]
|
|
UnexpectedResponseStructure(String),
|
|
#[error("{style_bold}{color_red}context index range is invalid: {style_reset}{color_reset}\n{0}")]
|
|
InvalidRange(String),
|
|
#[error("{style_bold}{color_red}api key does not exist in .env file: {style_reset}{color_reset}\n{0}")]
|
|
ApiKeyNotFound(Model),
|
|
#[error("{style_bold}{color_red}Error parsing{filename:?} as csv file: {style_reset}{color_reset}\n{source}")]
|
|
CsvFile {
|
|
#[source]
|
|
source: csv::Error,
|
|
filename: Option<PathBuf>,
|
|
},
|
|
#[error("{style_bold}{color_red}Error parsing json: {style_reset}{color_reset}\n{0}")]
|
|
SerdeJson(#[from] serde_json::Error),
|
|
#[error("{style_bold}{color_red}Error parsing {filename:?} as json: {style_reset}{color_reset}\n{source}")]
|
|
SerdeJsonFile {
|
|
#[source]
|
|
source: serde_json::Error,
|
|
filename: Option<PathBuf>,
|
|
},
|
|
#[error("{style_bold}{color_red}reqwest Error: {style_reset}{color_reset}\n{0}")]
|
|
Reqwest(#[from] reqwest::Error),
|
|
#[error("{style_bold}{color_red}io Error: {style_reset}{color_reset}\n{0}")]
|
|
Io(#[from] std::io::Error),
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct CsvCompany {
|
|
#[serde(rename = "Company Name")]
|
|
name: String,
|
|
#[serde(rename = "Website")]
|
|
website: String,
|
|
#[serde(rename = "Company City")]
|
|
city: String,
|
|
#[serde(rename = "# Employees")]
|
|
employee_count: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct FindPartnersTargetCompany {
|
|
#[serde(rename = "target_company")]
|
|
name: String,
|
|
connections: Vec<FindPartnersPartnerCompany>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct FindPartnersPartnerCompany {
|
|
category: String,
|
|
context: String,
|
|
partner_name: String,
|
|
people: Vec<String>,
|
|
source_type: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct EvalLeadsContact {
|
|
value: String,
|
|
#[serde(rename = "type")]
|
|
value_type: String,
|
|
category: String,
|
|
source_url: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct EvalLeadsEmployee {
|
|
name: String,
|
|
role: String,
|
|
email: String,
|
|
phone: String,
|
|
source_url: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct EvalLeadsLead {
|
|
company_name: String,
|
|
description: String,
|
|
employee_count: String,
|
|
employees: Vec<EvalLeadsEmployee>,
|
|
general_contacts: Vec<EvalLeadsContact>,
|
|
industry: String,
|
|
lead_attractiveness_score: usize,
|
|
scoring_reasoning: String,
|
|
website: String,
|
|
}
|
|
|
|
#[derive(ValueEnum, PartialEq, Clone, Copy, Default, Debug, Serialize)]
|
|
enum Model {
|
|
#[default]
|
|
Gemini,
|
|
OpenAI,
|
|
Perplexity,
|
|
}
|
|
|
|
impl Display for Model {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
Model::Gemini => write!(f, "Google Gemini: 2.5"),
|
|
Model::Perplexity => write!(f, "Perplexity AI: Sonar"),
|
|
Model::OpenAI => write!(f, "OpenAI: GPT-5"),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(ValueEnum, Clone, Copy, Default, Debug, Serialize)]
|
|
enum OutputFormat {
|
|
#[default]
|
|
Json,
|
|
Typst,
|
|
}
|
|
|
|
/// The API keys of all supported providers, bundled so they can be passed around together.
struct ApiKeys {
    /// Key sent in the `x-goog-api-key` header.
    gemini_api_key: String,

    /// Bearer token for the OpenAI API.
    openai_api_key: String,

    /// Bearer token for the Perplexity API.
    perplexity_api_key: String,
}
|
|
|
|
#[derive(Parser)]
|
|
struct Args {
|
|
#[command(subcommand)]
|
|
task: Task,
|
|
}
|
|
|
|
fn extract_response(content: &str, model: Model) -> Result<String, Error> {
|
|
let ser_all: serde_json::Value = serde_json::from_str(content)?;
|
|
let cont = match model {
|
|
Model::Gemini => ser_all
|
|
.get("candidates")
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get(0)
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get("content")
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get("parts")
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get(0)
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get("text")
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.as_str()
|
|
.unwrap()
|
|
.replace("\\\"", "\"")
|
|
.replace("\\n", "\n"),
|
|
Model::OpenAI => ser_all
|
|
.get("output")
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.as_array()
|
|
.and_then(|arr| arr.last())
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get("content")
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get(0)
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.get("text")
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.as_str()
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.to_owned(),
|
|
Model::Perplexity => {
|
|
let raw_msg = ser_all["choices"][0]["message"]["content"]
|
|
.as_str()
|
|
.ok_or(Error::UnexpectedResponseStructure(content.to_owned()))?
|
|
.replace("\\\"", "\"")
|
|
.replace("\\n", "\n");
|
|
let pref_stripped_msg = raw_msg.strip_prefix("```json").unwrap_or(&raw_msg);
|
|
let bidir_stripped_msg = pref_stripped_msg
|
|
.strip_suffix("```")
|
|
.unwrap_or(&pref_stripped_msg);
|
|
bidir_stripped_msg.trim().to_owned()
|
|
}
|
|
};
|
|
Ok(cont)
|
|
}
|
|
|
|
/// Validates that the response is any kind of json. Does not validate structure of json.
|
|
fn validate_json(cont: &str) -> Result<String, Error> {
|
|
let ser_cont: serde_json::Value =
|
|
serde_json::from_str(&cont).map_err(|source| Error::SerdeJson(source))?;
|
|
let cont_pretty = serde_json::to_string_pretty(&ser_cont)?;
|
|
Ok(cont_pretty)
|
|
}
|
|
|
|
fn parse_csv(file_path: &Path) -> Result<Vec<CsvCompany>, Error> {
|
|
let mut reader = csv::ReaderBuilder::new()
|
|
.has_headers(true)
|
|
.from_path(&file_path)
|
|
.map_err(|source| Error::CsvFile {
|
|
source,
|
|
filename: Some(file_path.to_owned()),
|
|
})?;
|
|
let mut companies = Vec::new();
|
|
for res in reader.deserialize() {
|
|
let record: CsvCompany = res.map_err(|source| Error::CsvFile {
|
|
source,
|
|
filename: Some(file_path.to_owned()),
|
|
})?;
|
|
companies.push(record);
|
|
}
|
|
Ok(companies)
|
|
}
|
|
|
|
async fn prompt(
|
|
model: Model,
|
|
reqclient: &reqwest::Client,
|
|
api_keys: &ApiKeys,
|
|
system_prompt: &str,
|
|
context_prompt: &str,
|
|
schema: &Option<String>,
|
|
) -> Result<reqwest::Response, reqwest::Error> {
|
|
match model {
|
|
Model::Gemini => {
|
|
gemini_prompt(
|
|
reqclient,
|
|
&api_keys.gemini_api_key,
|
|
system_prompt,
|
|
context_prompt,
|
|
)
|
|
.await
|
|
}
|
|
Model::OpenAI => {
|
|
openai_prompt(
|
|
reqclient,
|
|
&api_keys.openai_api_key,
|
|
system_prompt,
|
|
context_prompt,
|
|
schema,
|
|
)
|
|
.await
|
|
}
|
|
Model::Perplexity => {
|
|
perplexity_prompt(
|
|
reqclient,
|
|
&api_keys.perplexity_api_key,
|
|
system_prompt,
|
|
context_prompt,
|
|
)
|
|
.await
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn perplexity_prompt(
|
|
reqclient: &reqwest::Client,
|
|
perplexity_api_key: &str,
|
|
system_prompt: &str,
|
|
context_prompt: &str,
|
|
) -> Result<reqwest::Response, reqwest::Error> {
|
|
let body = format!(
|
|
r#"{{
|
|
"model": "sonar",
|
|
"messages": [
|
|
{{
|
|
"role": "system",
|
|
"content": "{system_prompt}"
|
|
}},
|
|
{{
|
|
"role": "user",
|
|
"content": "{context_prompt}"
|
|
}}
|
|
],
|
|
"temperature": 0
|
|
}}"#,
|
|
system_prompt = system_prompt.replace('"', "\\\"").replace('\n', "\\n"),
|
|
context_prompt = context_prompt.replace('"', "\\\"").replace('\n', "\\n")
|
|
);
|
|
|
|
reqclient
|
|
.post("https://api.perplexity.ai/chat/completions")
|
|
.header("accept", "application/json")
|
|
.header("Content-Type", "application/json")
|
|
.header("Authorization", format!("Bearer {perplexity_api_key}"))
|
|
.body(body)
|
|
.send()
|
|
.await
|
|
}
|
|
|
|
async fn openai_prompt(
|
|
reqclient: &reqwest::Client,
|
|
openai_api_key: &str,
|
|
system_prompt: &str,
|
|
context_prompt: &str,
|
|
schema: &Option<String>,
|
|
) -> Result<reqwest::Response, reqwest::Error> {
|
|
let schema_wrapped_fmt = if let Some(s) = schema {
|
|
Some(format!(
|
|
r#"
|
|
"text": {{
|
|
"format": {{
|
|
"name": "testname",
|
|
"type": "json_schema",
|
|
"schema":
|
|
{s_fmt}
|
|
}}
|
|
}}
|
|
"#,
|
|
s_fmt = s
|
|
))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let body = format!(
|
|
r#"
|
|
{{
|
|
"model": "gpt-5-mini",
|
|
"input": [
|
|
{{
|
|
"role": "system",
|
|
"content": "{system_prompt}"
|
|
}},
|
|
{{
|
|
"role": "user",
|
|
"content": "{context_prompt}"
|
|
}}
|
|
],
|
|
"tools": [
|
|
{{ "type": "web_search" }}
|
|
]{opt_comma}
|
|
{schema_code}
|
|
}}
|
|
"#,
|
|
system_prompt = system_prompt.replace('"', "\\\"").replace('\n', "\\n"),
|
|
context_prompt = context_prompt.replace('"', "\\\"").replace('\n', "\\n"),
|
|
opt_comma = if schema_wrapped_fmt.is_some() {
|
|
","
|
|
} else {
|
|
""
|
|
},
|
|
schema_code = if let Some(s) = schema_wrapped_fmt {
|
|
s
|
|
} else {
|
|
String::new()
|
|
},
|
|
);
|
|
|
|
reqclient
|
|
.post("https://api.openai.com/v1/responses")
|
|
.header("Content-Type", "application/json")
|
|
.header("Authorization", format!("Bearer {openai_api_key}"))
|
|
.body(body)
|
|
.send()
|
|
.await
|
|
}
|
|
|
|
/// Implements \<QUERY\>\<QUERY\>
|
|
async fn gemini_prompt(
|
|
reqclient: &reqwest::Client,
|
|
gemini_api_key: &str,
|
|
system_prompt: &str,
|
|
context_prompt: &str,
|
|
) -> Result<reqwest::Response, reqwest::Error> {
|
|
let body = format!(
|
|
r#"{{
|
|
"contents": [
|
|
{{
|
|
"parts": [
|
|
{{
|
|
"text": "SYSTEM PROMPT:\n\n\n{system_prompt}\n\nContext Prompt:{context_prompt}\n\n\n\nSYSTEM PROMPT:\n\n\n{system_prompt}\n\nContext Prompt:{context_prompt}\n"
|
|
}}
|
|
]
|
|
}}
|
|
],
|
|
"tools": [
|
|
{{
|
|
"google_search": {{}}
|
|
}}
|
|
],
|
|
"generationConfig": {{
|
|
"temperature": 0.0,
|
|
}}
|
|
}}"#,
|
|
system_prompt = system_prompt.replace('"', "\\\"").replace('\n', "\\n"),
|
|
context_prompt = context_prompt.replace('"', "\\\"").replace('\n', "\\n")
|
|
);
|
|
|
|
reqclient
|
|
.post("https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent")
|
|
.header("x-goog-api-key", format!("{gemini_api_key}"))
|
|
.header("Content-Type", "application/json")
|
|
.body(body)
|
|
.send()
|
|
.await
|
|
}
|
|
|
|
async fn run(args: Args) -> Result<(), Error> {
|
|
let _ = dotenv::dotenv();
|
|
let mut vars = env::vars().collect::<HashMap<String, String>>();
|
|
let client = reqwest::Client::new();
|
|
|
|
let gemini_key = match vars.remove("GEMINI_API_KEY") {
|
|
Some(key) => key,
|
|
None => {
|
|
return Err(Error::ApiKeyNotFound(Model::Gemini));
|
|
}
|
|
};
|
|
let openai_key = match vars.remove("OPENAI_API_KEY") {
|
|
Some(key) => key,
|
|
None => {
|
|
return Err(Error::ApiKeyNotFound(Model::OpenAI));
|
|
}
|
|
};
|
|
let perplexity_key = match vars.remove("PERPLEXITY_API_KEY") {
|
|
Some(key) => key,
|
|
None => {
|
|
return Err(Error::ApiKeyNotFound(Model::Perplexity));
|
|
}
|
|
};
|
|
let api_keys = ApiKeys {
|
|
gemini_api_key: gemini_key,
|
|
openai_api_key: openai_key,
|
|
perplexity_api_key: perplexity_key,
|
|
};
|
|
|
|
match args.task {
|
|
Task::List { csv_file } => {
|
|
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
|
let csv = parse_csv(&pathbuf)?;
|
|
for (idx, company) in csv.iter().enumerate() {
|
|
println!(
|
|
"n = {:<4}\nCompany = {:<50?}\nWebsite: {:<30?}\nEmployee Count: {:<10?}\n",
|
|
idx + 1,
|
|
company.name,
|
|
company.website,
|
|
company.employee_count.as_deref().unwrap_or("unknown")
|
|
);
|
|
}
|
|
}
|
|
Task::CountPartners { json_file } => {
|
|
let pathbuf = PathBuf::from_str(&json_file).unwrap_or_else(|_| unreachable!());
|
|
let datafile = fs::read_to_string(json_file).unwrap();
|
|
let companies: Vec<FindPartnersTargetCompany> = serde_json::from_str(&datafile)
|
|
.map_err(|source| Error::SerdeJsonFile {
|
|
source,
|
|
filename: Some(pathbuf),
|
|
})?;
|
|
let mut partner_occurrences = HashMap::new();
|
|
for target_company in companies {
|
|
let connections_array = target_company.connections;
|
|
for connection in connections_array {
|
|
if connection.category == "REFERENCE_CLIENT" {
|
|
partner_occurrences
|
|
.entry(connection.partner_name)
|
|
.and_modify(|count| *count += 1)
|
|
.or_insert(1);
|
|
}
|
|
}
|
|
}
|
|
let mut count_vec: Vec<_> = partner_occurrences.iter().collect();
|
|
count_vec.sort_by(|a, b| b.1.cmp(a.1));
|
|
for (k, v) in count_vec {
|
|
println!("Company {:>60} appears {:>3} times", k, v);
|
|
}
|
|
}
|
|
Task::TestWithPandaloop {
|
|
model,
|
|
pretty,
|
|
system_prompt_file,
|
|
json_schema_file,
|
|
} => {
|
|
let system_prompt = fs::read_to_string(system_prompt_file)?;
|
|
let json_schema = if let Some(json) = json_schema_file {
|
|
let x = fs::read_to_string(json)?;
|
|
Some(x)
|
|
} else {
|
|
None
|
|
};
|
|
let context_prompt =
|
|
"Research Target Company T: Pandaloop\nWebsite: https://pandaloop.de\n".to_owned();
|
|
|
|
let mut response = prompt(
|
|
model,
|
|
&client,
|
|
&api_keys,
|
|
&system_prompt,
|
|
&context_prompt,
|
|
&json_schema,
|
|
)
|
|
.await?
|
|
.text()
|
|
.await?;
|
|
if pretty {
|
|
response = extract_response(&response, model)?;
|
|
response = validate_json(&response)?;
|
|
}
|
|
println!("{response}");
|
|
}
|
|
Task::SinglePrompt {
|
|
idx,
|
|
csv_file,
|
|
model,
|
|
system_prompt_file,
|
|
pretty,
|
|
json_schema_file,
|
|
} => {
|
|
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
|
let csv = parse_csv(&pathbuf)?;
|
|
let system_prompt = fs::read_to_string(system_prompt_file)?;
|
|
let json_schema = if let Some(json) = json_schema_file {
|
|
Some(fs::read_to_string(json)?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
if let Some(company) = csv.get(idx-1) {
|
|
let context_prompt = format!(
|
|
"Research Target Company: {}\nWebsite: {}\n",
|
|
company.name, company.website
|
|
);
|
|
let mut response = prompt(
|
|
model,
|
|
&client,
|
|
&api_keys,
|
|
&system_prompt,
|
|
&context_prompt,
|
|
&json_schema,
|
|
)
|
|
.await?
|
|
.text()
|
|
.await?;
|
|
if pretty {
|
|
response = extract_response(&response, model)?;
|
|
response = validate_json(&response)?;
|
|
}
|
|
println!("{response}");
|
|
} else {
|
|
eprintln!("Index {:?} out of range, skipping.", idx);
|
|
}
|
|
}
|
|
Task::RangePrompt {
|
|
range_start,
|
|
range_end,
|
|
csv_file,
|
|
model,
|
|
system_prompt_file,
|
|
pretty,
|
|
json_schema_file,
|
|
} => {
|
|
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
|
let csv = parse_csv(&pathbuf)?;
|
|
let system_prompt = fs::read_to_string(system_prompt_file)?;
|
|
let json_schema = if let Some(json) = json_schema_file {
|
|
Some(fs::read_to_string(json)?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
if range_start > range_end {
|
|
return Err(Error::InvalidRange(format!(
|
|
"range args {range_start}..{range_end} are invalid: range_start should be smaller than range_end"
|
|
)));
|
|
}
|
|
let mut context_prompt = String::new();
|
|
for i in range_start..=range_end {
|
|
if let Some(company) = csv.get(i-1) {
|
|
context_prompt = format!(
|
|
"{}\nResearch Target Company T: {}\nWebsite: {}\nEmployee Count: {}\n\n",
|
|
context_prompt,
|
|
company.name,
|
|
company.website,
|
|
company.employee_count.as_deref().unwrap_or("unknown")
|
|
);
|
|
} else {
|
|
eprintln!("Index {:?} out of range, skipping.", i);
|
|
}
|
|
}
|
|
let mut response = prompt(
|
|
model,
|
|
&client,
|
|
&api_keys,
|
|
&system_prompt,
|
|
&context_prompt,
|
|
&json_schema,
|
|
)
|
|
.await?
|
|
.text()
|
|
.await?;
|
|
if pretty {
|
|
response = extract_response(&response, model)?;
|
|
response = validate_json(&response)?;
|
|
}
|
|
println!("{response}");
|
|
}
|
|
Task::MultiPrompt {
|
|
range_start,
|
|
range_end,
|
|
range_step,
|
|
csv_file,
|
|
model,
|
|
system_prompt_file,
|
|
pretty,
|
|
json_schema_file,
|
|
} => {
|
|
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
|
|
let csv = parse_csv(&pathbuf)?;
|
|
let system_prompt = fs::read_to_string(system_prompt_file)?;
|
|
let json_schema = if let Some(json) = json_schema_file {
|
|
Some(fs::read_to_string(json)?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
if range_start > range_end {
|
|
return Err(Error::InvalidRange(format!(
|
|
"range args {range_start}..{range_end} are invalid: range_start should be smaller than range_end"
|
|
)));
|
|
}
|
|
let mut i = range_start;
|
|
while i <= range_end {
|
|
let mut context_prompt = String::new();
|
|
for j in i..i + range_step {
|
|
if let Some(company) = csv.get(j-1) {
|
|
context_prompt = format!(
|
|
"{}\nResearch Target Company T: {}\nWebsite: {}\n\n\n",
|
|
context_prompt, company.name, company.website
|
|
);
|
|
} else {
|
|
eprintln!("Index {:?} out of range, skipping.", j);
|
|
}
|
|
}
|
|
eprintln!(
|
|
"{style_bold}{color_bright_blue}Dispatching prompt for range {}..={}{style_reset}{color_reset}",
|
|
i,
|
|
i + range_step - 1
|
|
);
|
|
let response_result = prompt(
|
|
model,
|
|
&client,
|
|
&api_keys,
|
|
&system_prompt,
|
|
&context_prompt,
|
|
&json_schema,
|
|
)
|
|
.await;
|
|
|
|
match response_result {
|
|
Ok(response) => {
|
|
let response_status = response.status();
|
|
let response_text = response.text().await?;
|
|
match response_status {
|
|
reqwest::StatusCode::OK => {
|
|
if pretty {
|
|
match extract_response(&response_text, model) {
|
|
Ok(response_text_pretty) => {
|
|
println!("{response_text_pretty}");
|
|
i += range_step;
|
|
}
|
|
Err(e) => {
|
|
return Err(e);
|
|
}
|
|
}
|
|
} else {
|
|
println!("{response_text}");
|
|
i += range_step;
|
|
}
|
|
}
|
|
reqwest::StatusCode::TOO_MANY_REQUESTS => {
|
|
eprintln!("{style_bold}{color_bright_blue}429 reached. Go to bed.{color_reset}{style_reset}");
|
|
return Err(Error::APIResponseError(response_text))
|
|
}
|
|
reqwest::StatusCode::NOT_FOUND => {
|
|
return Err(Error::APIResponseError(response_text))
|
|
}
|
|
_ => {
|
|
eprintln!("{style_bold}{color_bright_red}Response status != 200:{color_reset}{style_reset}\n{response_text}\n{style_bold}{color_bright_blue}Retrying in 5 seconds...{color_reset}{style_reset}");
|
|
time::sleep(time::Duration::from_millis(5000)).await;
|
|
}
|
|
}
|
|
|
|
}
|
|
Err(e) => {
|
|
return Err(Error::Reqwest(e));
|
|
}
|
|
}
|
|
|
|
if model == Model::OpenAI {
|
|
time::sleep(time::Duration::from_millis(60000)).await;
|
|
}
|
|
}
|
|
}
|
|
Task::ProcessLeadEvalJson {
|
|
json_file,
|
|
output_format,
|
|
min_attractiveness,
|
|
} => {
|
|
let raw_json = fs::read_to_string(json_file)?;
|
|
let heuristic = raw_json
|
|
.replace("]\n[\n", "")
|
|
.replace("}\n{", "},\n{")
|
|
.replace("}\n {", "},\n {");
|
|
let parsed: Vec<EvalLeadsLead> = serde_json::from_str(&heuristic)?;
|
|
match output_format {
|
|
OutputFormat::Json => {
|
|
let serialized = serde_json::to_string_pretty(&parsed)?;
|
|
println!("{serialized}");
|
|
return Ok(());
|
|
}
|
|
OutputFormat::Typst => {
|
|
let mut typcode = String::new();
|
|
write!(
|
|
typcode,
|
|
r#"#set table(inset: 6pt)
|
|
#set text(font: "Geist")
|
|
#set page(margin: (left: 1cm, right: 1cm))
|
|
#set heading(numbering: "1.")
|
|
#outline()
|
|
#pagebreak()
|
|
|
|
"#
|
|
)
|
|
.unwrap_or_else(|_| unreachable!());
|
|
for lead in parsed {
|
|
if lead.lead_attractiveness_score < min_attractiveness {
|
|
continue;
|
|
}
|
|
write!(
|
|
typcode,
|
|
r#"#table(
|
|
columns: (1fr, 0.5fr),
|
|
fill: luma(230),
|
|
table.cell(stroke: (right: none), [= {name}]),
|
|
table.cell(stroke: (left: none), align(right, [{domain}])),
|
|
[
|
|
{desc}
|
|
|
|
{industry}
|
|
],
|
|
[Teamgröße: {empl}],
|
|
table.cell(colspan: 2, [
|
|
*Allgemeine Kontakte*:
|
|
"#,
|
|
name = lead
|
|
.company_name
|
|
.replace("*", "\\*")
|
|
.replace("_", "\\_")
|
|
.replace(r"@", r"\@")
|
|
.replace(r"<", r"\<"),
|
|
domain = lead
|
|
.website
|
|
.replace("*", "\\*")
|
|
.replace("_", "\\_")
|
|
.replace(r"@", r"\@")
|
|
.replace(r"<", r"\<"),
|
|
desc = lead
|
|
.description
|
|
.replace("*", "\\*")
|
|
.replace("_", "\\_")
|
|
.replace(r"@", r"\@")
|
|
.replace(r"<", r"\<"),
|
|
industry = lead
|
|
.industry
|
|
.replace("*", "\\*")
|
|
.replace("_", "\\_")
|
|
.replace(r"@", r"\@")
|
|
.replace(r"<", r"\<"),
|
|
empl = lead
|
|
.employee_count
|
|
.replace("*", "\\*")
|
|
.replace("_", "\\_")
|
|
.replace(r"@", r"\@")
|
|
.replace(r"<", r"\<"),
|
|
)
|
|
.unwrap_or_else(|_| unreachable!());
|
|
|
|
for contact in lead.general_contacts {
|
|
write!(
|
|
typcode,
|
|
r#" - {val} (Art / Abteilung: {type})
|
|
"#,
|
|
val = contact.value.replace(r"*", r"\*").replace(r"@", r"\@").replace(r"_", r"\_"),
|
|
type = {
|
|
match contact.category.as_str() {
|
|
"SALES_DIRECT" => "Sales",
|
|
"GENERAL_INFO" => "Allgemein",
|
|
"SUPPORT" => "Support / Hotline",
|
|
"PRESS_MARKETING" => "Presse / Marketing",
|
|
"OTHER" => "Unbekannt",
|
|
_ => &contact.value_type
|
|
}
|
|
}
|
|
)
|
|
.unwrap_or_else(|_| unreachable!());
|
|
}
|
|
write!(
|
|
typcode,
|
|
r#"]),
|
|
table.cell(colspan: 2, [
|
|
*Mitarbeiter*:
|
|
#table(
|
|
stroke: (paint: luma(100), thickness: 1pt, dash: "dashed"),
|
|
columns: (1fr, auto, auto),
|
|
"#
|
|
)
|
|
.unwrap_or_else(|_| unreachable!());
|
|
for empl in lead.employees {
|
|
write!(
|
|
typcode,
|
|
r#" "{name}, {role}", link("mailto:{mail}", "{mail}"), link("tel:{tel}", "{tel}"),
|
|
"#,
|
|
name = empl.name.replace("*", "\\*").replace("_", "\\_"),
|
|
role = empl.role.replace("*", "\\*").replace("_", "\\_"),
|
|
mail = {
|
|
let x = if empl.email.is_empty() || empl.email == "null" {
|
|
"-".to_owned()
|
|
} else {
|
|
empl.email.replace("*", "\\*").replace("_", "\\_")
|
|
};
|
|
x
|
|
},
|
|
tel = {
|
|
let x = if empl.phone.is_empty() || empl.phone == "null" {
|
|
"-".to_owned()
|
|
} else {
|
|
empl.phone.replace("*", "\\*").replace("_", "\\_")
|
|
};
|
|
x
|
|
}
|
|
)
|
|
.unwrap_or_else(|_| unreachable!());
|
|
}
|
|
write!(
|
|
typcode,
|
|
r#" )
|
|
]),
|
|
)
|
|
"#
|
|
)
|
|
.unwrap_or_else(|_| unreachable!());
|
|
}
|
|
println!("{typcode}");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
|
|
// - denkwerk Partner: *fraenk*, *OTTO*, *SWR*, *Union Investment*, *ARAG*, erenja, BurdaForward, Stiebel Eltron, edding, polymore, Telekom, Santander, DeepL, motel one, Struggly, Fondation Beyeler, Charite, mainova, Esprit, Sport Cast, Remondis, condor, , Aktion Mensch, easy credit, Sparkasse KölnBonn, multiloop, Storck, Teambank, Ranger, Microsoft
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> ExitCode {
|
|
let args = Args::parse();
|
|
match run(args).await {
|
|
Ok(_) => ExitCode::SUCCESS,
|
|
Err(e) => {
|
|
eprintln!("{e}");
|
|
ExitCode::FAILURE
|
|
}
|
|
}
|
|
}
|